Files
backend/session.go
NanamiAdmin e400cc1869 feat(session): add session list API
- Move token-related functions from auth.go to new session.go
- Add session tracking with expiration and cleanup
- Implement session list API endpoint
- Update login/logout handlers to use session system
- Add hourly cleanup of expired tokens and sessions
2026-03-05 18:24:19 +08:00

320 lines
7.2 KiB
Go

package main
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"strings"
"super-frpc/postLog"
"sync"
"time"
)
// TokenInfo is the auth token currently issued to one user. Only one token
// per user is kept: tokenMap is keyed by user ID, so issuing a new token
// replaces the previous one.
type TokenInfo struct {
// Token is the base64url-encoded SHA-256 digest handed to the client.
Token string
// CreatedAt is the issue time; ValidateToken and CleanupExpiredTokens treat
// the token as expired once tokenTTL has elapsed since this instant.
CreatedAt time.Time
// UserID identifies the token's owner and mirrors the tokenMap key.
UserID int
}
// Token store. tokenMux guards every read and write of tokenMap.
var (
// tokenMap holds the current token for each user, keyed by user ID.
tokenMap = make(map[int]*TokenInfo)
tokenMux sync.RWMutex
// tokenTTL is a token's lifetime, measured from TokenInfo.CreatedAt.
tokenTTL = time.Hour
)
// Session is one login recorded by JoinSession. A user may have several
// sessions at once because every JoinSession call creates a new entry.
type Session struct {
// ID is the identifier produced by generateSessionID; it is also the
// sessionMap key.
ID string
UserID int
Username string
// ExpireAt is sessionTTL after creation. ListActiveSessions hides, and
// CleanupExpiredSessions deletes, sessions whose ExpireAt has passed.
ExpireAt time.Time
}
// Session store. sessionMux guards every read and write of sessionMap.
// NOTE(review): RemoveSession locks sessionMux before tokenMux; any code that
// needs both locks should acquire them in that same order to avoid deadlock.
var (
// sessionMap holds all sessions keyed by Session.ID.
sessionMap = make(map[string]*Session)
sessionMux sync.RWMutex
// sessionTTL is how long a session stays valid after JoinSession creates it.
sessionTTL = time.Hour
)
// GenerateToken issues a fresh random token for userID, replacing any token
// the user already had, and returns it.
//
// The token is the base64url encoding of SHA-256(32 random bytes || decimal
// userID). An error is returned only when the system random source fails.
//
// NOTE(review): the debug log line contains the full token.
func GenerateToken(userID int) (string, error) {
	seed := make([]byte, 32)
	if _, readErr := rand.Read(seed); readErr != nil {
		return "", fmt.Errorf("Failed to generate random bytes: %w", readErr)
	}
	seed = append(seed, strconv.Itoa(userID)...)
	digest := sha256.Sum256(seed)
	token := base64.URLEncoding.EncodeToString(digest[:])

	tokenMux.Lock()
	defer tokenMux.Unlock()

	tokenMap[userID] = &TokenInfo{UserID: userID, Token: token, CreatedAt: time.Now()}
	postLog.Debug(fmt.Sprintf("[GenerateToken] Generated token for userID %d: %s", userID, token))
	return token, nil
}
// ValidateToken reports whether token is the current, unexpired token for
// userID. It returns nil on success and a descriptive error otherwise:
// no token stored, a different token stored, or the token is older than
// tokenTTL.
//
// NOTE(review): the error messages embed the presented token; avoid logging
// them verbatim or echoing them to clients.
func ValidateToken(userID int, token string) error {
	tokenMux.RLock()
	defer tokenMux.RUnlock()

	info, found := tokenMap[userID]
	switch {
	case !found:
		return fmt.Errorf("Token not found for userID %d: %s", userID, token)
	case info.Token != token:
		return fmt.Errorf("Invalid token for userID %d: %s", userID, token)
	case time.Since(info.CreatedAt) > tokenTTL:
		return fmt.Errorf("Token expired for userID %d: %s", userID, token)
	default:
		return nil
	}
}
// RefreshToken replaces userID's token with a newly generated one and returns
// it. The random bytes are drawn and the map is updated while tokenMux is
// held. An error is returned only when the system random source fails.
//
// NOTE(review): the debug log line contains the full token.
func RefreshToken(userID int) (string, error) {
	tokenMux.Lock()
	defer tokenMux.Unlock()

	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", fmt.Errorf("Failed to generate random bytes: %w", err)
	}
	material := append(entropy[:], strconv.Itoa(userID)...)
	sum := sha256.Sum256(material)
	fresh := base64.URLEncoding.EncodeToString(sum[:])

	tokenMap[userID] = &TokenInfo{
		UserID:    userID,
		Token:     fresh,
		CreatedAt: time.Now(),
	}
	postLog.Debug(fmt.Sprintf("[RefreshToken] Refreshed token for userID %d: %s", userID, fresh))
	return fresh, nil
}
// RemoveToken deletes the token stored for userID. It is a no-op when the
// user has no token.
func RemoveToken(userID int) {
	tokenMux.Lock()
	defer tokenMux.Unlock()

	info, ok := tokenMap[userID]
	if !ok {
		return
	}
	delete(tokenMap, userID)
	// Log from the entry captured before delete: after delete, tokenMap[userID]
	// yields a nil *TokenInfo, and reading .Token from it would panic.
	postLog.Debug(fmt.Sprintf("[RemoveToken] Removed token for userID %d: %s", userID, info.Token))
}
// GetTokenInfo returns the stored token record for userID, or an error when
// the user has no token. Expiry is not checked here. The returned pointer is
// the same one held in tokenMap.
func GetTokenInfo(userID int) (*TokenInfo, error) {
	tokenMux.RLock()
	info, ok := tokenMap[userID]
	tokenMux.RUnlock()

	if ok {
		return info, nil
	}
	return nil, fmt.Errorf("Token not found for userID %d", userID)
}
// extractUserIDFromToken finds the user whose current token equals token by
// scanning every entry of tokenMap. Expiry is not checked. It returns an
// error when no user holds the token.
func extractUserIDFromToken(token string) (int, error) {
	tokenMux.RLock()
	defer tokenMux.RUnlock()

	for id, info := range tokenMap {
		if info.Token != token {
			continue
		}
		return id, nil
	}
	return 0, fmt.Errorf("Invalid token: %s", token)
}
// CleanupExpiredTokens deletes every token older than tokenTTL. It then calls
// RemoveSession for each affected user so that one of their sessions is
// dropped as well.
//
// RemoveSession acquires sessionMux and then tokenMux itself, so it must not
// be called while tokenMux is held. sync.RWMutex is not reentrant, and holding
// tokenMux here would also invert RemoveSession's lock order. Expired entries
// are therefore collected under the lock, and sessions are removed only after
// the lock has been released.
func CleanupExpiredTokens() {
	expired := make(map[int]string)

	tokenMux.Lock()
	for userID, tokenInfo := range tokenMap {
		if time.Since(tokenInfo.CreatedAt) > tokenTTL {
			delete(tokenMap, userID)
			expired[userID] = tokenInfo.Token
			postLog.Debug(fmt.Sprintf("[CleanupExpiredTokens] Removed expired token for userID %d: %s", userID, tokenInfo.Token))
		}
	}
	tokenMux.Unlock()

	for userID, token := range expired {
		// Pass the expired token rather than "". If the user was issued a new
		// token after the lock was released, RemoveSession then leaves that
		// new token in place instead of deleting it.
		RemoveSession(userID, token)
	}
}
// CleanupExpiredSessions deletes every session whose ExpireAt is in the past
// and logs each removal at debug level.
func CleanupExpiredSessions() {
	sessionMux.Lock()
	defer sessionMux.Unlock()

	for id, s := range sessionMap {
		if !time.Now().After(s.ExpireAt) {
			continue
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(sessionMap, id)
		postLog.Debug(fmt.Sprintf("[CleanupExpiredSessions] Removed expired session %s for user [%d]%s", id, s.UserID, s.Username))
	}
}
// hashPassword returns the lowercase hex encoding of the SHA-256 digest of
// password. The error result is always nil.
//
// NOTE(review): a single unsalted SHA-256 is fast to brute-force. A dedicated
// password hash (bcrypt, scrypt, argon2) would be stronger, but switching
// requires migrating the stored hashes.
func hashPassword(password string) (string, error) {
	digest := sha256.New()
	digest.Write([]byte(password))
	return hex.EncodeToString(digest.Sum(nil)), nil
}
// verifyPassword reports whether hashPassword(password) equals
// hashedPassword. A hashing failure is logged and treated as a mismatch.
//
// NOTE(review): the final comparison uses ==, which is not constant-time;
// crypto/subtle.ConstantTimeCompare would avoid leaking timing information.
func verifyPassword(password, hashedPassword string) bool {
	candidate, err := hashPassword(password)
	if err == nil {
		return candidate == hashedPassword
	}
	postLog.Error(fmt.Sprintf("[verifyPassword] Failed to hash password: %v", err))
	return false
}
// isValidPassword reports whether password meets the complexity policy:
//   - at least 8 characters, counted as Unicode code points rather than bytes;
//   - at least one ASCII uppercase letter;
//   - at least one ASCII lowercase letter;
//   - at least one ASCII digit;
//   - at least one character from specialChars.
//
// It only validates; it does not hash the password.
func isValidPassword(password string) bool {
	const minLength = 8
	const specialChars = "!@#$%^&*()_+-=[]{}|;:,.<>?"

	// Fast path: the byte length is never smaller than the rune count.
	if len(password) < minLength {
		return false
	}

	var hasUpper, hasLower, hasDigit, hasSpecial bool
	runeCount := 0
	for _, char := range password {
		runeCount++
		switch {
		case char >= 'A' && char <= 'Z':
			hasUpper = true
		case char >= 'a' && char <= 'z':
			hasLower = true
		case char >= '0' && char <= '9':
			hasDigit = true
		case strings.ContainsRune(specialChars, char):
			hasSpecial = true
		}
	}
	// Use the rune count so multi-byte characters cannot satisfy the minimum
	// length with fewer than minLength characters.
	return runeCount >= minLength && hasUpper && hasLower && hasDigit && hasSpecial
}
// ValidateTimeStamp reports whether the request's X-Timestamp header, a Unix
// time in milliseconds, lies within maxSkewMillis of the server clock in
// either direction. A missing or non-integer header is rejected. When isDebug
// is set, the check is bypassed and every request passes.
func ValidateTimeStamp(header http.Header) bool {
	const maxSkewMillis int64 = 3000

	if isDebug {
		return true
	}
	timeStampStr := header.Get("X-Timestamp")
	if timeStampStr == "" {
		return false
	}
	timeStamp, err := strconv.ParseInt(timeStampStr, 10, 64)
	if err != nil {
		return false
	}
	// Compare the client value against a window around the current time
	// instead of subtracting it. currentTime-timeStamp overflows int64 for
	// extreme client values. For timeStamp == currentTime+MinInt64, both
	// differences wrap to MinInt64, and a difference-based check accepts the
	// request.
	currentTime := time.Now().UnixMilli()
	return timeStamp >= currentTime-maxSkewMillis && timeStamp <= currentTime+maxSkewMillis
}
// GetUserIDFromToken returns the ID of the user whose current token equals
// token. A lookup failure is returned wrapped with additional context, and
// the ID is 0 in that case.
func GetUserIDFromToken(token string) (int, error) {
	id, lookupErr := extractUserIDFromToken(token)
	if lookupErr == nil {
		return id, nil
	}
	return 0, fmt.Errorf("Failed to extract userID from token: %w", lookupErr)
}
// DeleteTokenInfo removes userID's token, if any, and always returns nil.
// Deleting a missing entry is harmless; the debug line is logged either way.
func DeleteTokenInfo(userID int) error {
	tokenMux.Lock()
	defer tokenMux.Unlock()

	delete(tokenMap, userID)
	msg := fmt.Sprintf("[DeleteTokenInfo] Removed token for userID %d", userID)
	postLog.Debug(msg)
	return nil
}
// GetUsernameByID returns the username GetUserByID reports for userID, or ""
// when that lookup fails.
func GetUsernameByID(userID int) string {
	if user, err := GetUserByID(userID); err == nil {
		return user.Username
	}
	return ""
}
// generateSessionID returns "session-<32 hex chars>-<UnixNano>", where the
// hex part encodes 16 bytes read from crypto/rand. If the random source fails,
// it falls back to "session-<UnixNano>-<Unix seconds>", which is predictable.
func generateSessionID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err == nil {
		return fmt.Sprintf("session-%x-%d", buf[:], time.Now().UnixNano())
	}
	return fmt.Sprintf("session-%d-%s", time.Now().UnixNano(), strconv.FormatInt(time.Now().Unix(), 10))
}
// JoinSession records a new session for userID/userName that expires
// sessionTTL from now, and logs it. A fresh session is created on every call,
// so one user may hold several sessions at once. The error result is always
// nil.
//
// token is used only in the log line. Only a short prefix of it is written,
// so the bearer token never lands in the logs in full.
func JoinSession(userID int, userName string, token string) error {
	const visibleTokenChars = 6

	redacted := token
	if len(redacted) > visibleTokenChars {
		redacted = redacted[:visibleTokenChars] + "..."
	}

	sessionMux.Lock()
	defer sessionMux.Unlock()

	sessionID := generateSessionID()
	sessionMap[sessionID] = &Session{
		ID:       sessionID,
		UserID:   userID,
		Username: userName,
		ExpireAt: time.Now().Add(sessionTTL),
	}
	postLog.Info(fmt.Sprintf("[JoinSession] User [%d]%s joined session %s with token %s", userID, userName, sessionID, redacted))
	return nil
}
// RemoveSession removes one session belonging to userID. It also removes the
// user's stored token when token matches it or token is "". The error result
// is always nil.
//
// At most one session is removed. If the user has several, which one goes
// depends on map iteration order. sessionMux is held for the whole call and
// tokenMux is acquired while it is held, so callers must not already hold
// tokenMux.
func RemoveSession(userID int, token string) error {
	sessionMux.Lock()
	defer sessionMux.Unlock()

	victim := ""
	for id, s := range sessionMap {
		if s.UserID != userID {
			continue
		}
		victim = id
		break
	}
	if victim != "" {
		delete(sessionMap, victim)
		postLog.Info(fmt.Sprintf("[RemoveSession] Removed session %s for user [%d]%s", victim, userID, GetUsernameByID(userID)))
	}

	tokenMux.Lock()
	defer tokenMux.Unlock()

	info, ok := tokenMap[userID]
	if ok && (token == "" || info.Token == token) {
		delete(tokenMap, userID)
		postLog.Info(fmt.Sprintf("[RemoveSession] Removed token for user [%d]%s", userID, GetUsernameByID(userID)))
	}
	return nil
}
// ListActiveSessions returns the sessions whose ExpireAt is still in the
// future, in no particular order. All sessions are judged against a single
// instant taken once at the start of the scan.
//
// The result is never nil. When there are no active sessions it is an empty
// slice, so encoding/json emits [] rather than null for it. The returned
// pointers are the same ones held in sessionMap.
func ListActiveSessions() []*Session {
	sessionMux.RLock()
	defer sessionMux.RUnlock()

	now := time.Now()
	activeSessions := make([]*Session, 0, len(sessionMap))
	for _, session := range sessionMap {
		if session.ExpireAt.After(now) {
			activeSessions = append(activeSessions, session)
		}
	}
	return activeSessions
}