Files
backend/session/session.go
NanamiAdmin ddf91299e7 refactor: reorganize codebase into modular packages
feat(global): add global package for shared variables and types
refactor(handlers): move handlers to dedicated package and update imports
refactor(session): extract session management to separate package
refactor(config): move config handling to dedicated package
refactor(router): update route handlers to use new package structure
refactor(main): simplify main.go by moving logic to packages
2026-04-22 12:57:04 +08:00

372 lines
8.5 KiB
Go

package session
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"super-frpc/database"
"super-frpc/postLog"
)
// TokenInfo records a token issued to a user together with its issue time.
// One entry per user ID is kept in tokenMap.
type TokenInfo struct {
// Token is the encoded token string returned by GenerateToken/RefreshToken.
Token string
// CreatedAt is the time the token was stored; expiry is measured from it
// against tokenTTL.
CreatedAt time.Time
// UserID is the ID of the user the token was generated for.
UserID int
}
// Session describes one login session tracked in sessionMap.
type Session struct {
// ID is the session identifier produced by generateSessionID.
ID string
// UserID is the ID of the user that owns the session.
UserID int
// Username is the user name supplied when the session was joined.
Username string
// ExpireAt is the instant after which the session counts as expired
// (JoinSession sets it to creation time plus sessionTTL).
ExpireAt time.Time
}
// Token state: at most one TokenInfo per user ID.
var (
// tokenMap maps a user ID to that user's current token.
tokenMap = make(map[int]*TokenInfo)
// tokenMux is the lock the token functions take around tokenMap.
tokenMux sync.RWMutex
// tokenTTL is how long a token stays valid after CreatedAt.
tokenTTL = time.Hour
)
// Session state, keyed by session ID.
var (
// sessionMap maps a session ID to its Session record.
sessionMap = make(map[string]*Session)
// sessionMux is the lock the session functions take around sessionMap.
sessionMux sync.RWMutex
// sessionTTL is the lifetime given to a session when it is created.
sessionTTL = time.Hour
)
// Session-to-token association.
var (
// sessionTokenMap maps a session ID to the token it was joined with.
sessionTokenMap = make(map[string]string)
// sessionTokenMux is the lock intended to guard sessionTokenMap.
// NOTE(review): not every accessor in this file takes it — confirm.
sessionTokenMux sync.RWMutex
)
// GenerateToken creates a fresh random token for userID, stores it in
// tokenMap (replacing any previous token for that user) and returns it.
//
// The token is the URL-safe base64 encoding of SHA-256(32 random bytes ||
// decimal userID). An error is returned only when the random source fails,
// in which case nothing is stored.
func GenerateToken(userID int) (string, error) {
	seed := make([]byte, 32)
	if _, err := rand.Read(seed); err != nil {
		return "", fmt.Errorf("Failed to generate random bytes: %w", err)
	}

	// Hash the random seed followed by the decimal user ID.
	hasher := sha256.New()
	hasher.Write(seed)
	hasher.Write([]byte(fmt.Sprintf("%d", userID)))
	token := base64.URLEncoding.EncodeToString(hasher.Sum(nil))

	tokenMux.Lock()
	defer tokenMux.Unlock()

	entry := &TokenInfo{
		Token:     token,
		CreatedAt: time.Now(),
		UserID:    userID,
	}
	tokenMap[userID] = entry

	postLog.Debug(fmt.Sprintf("[GenerateToken] Generated token for userID %d: %s", userID, token))
	return token, nil
}
// ValidateToken checks that token is the current, unexpired token stored for
// userID. It returns nil on success and a descriptive error when the user has
// no token, the token differs from the stored one, or the stored token is
// older than tokenTTL (checked in that order).
func ValidateToken(userID int, token string) error {
	tokenMux.RLock()
	defer tokenMux.RUnlock()

	stored, found := tokenMap[userID]
	switch {
	case !found:
		return fmt.Errorf("Token not found for userID %d: %s", userID, token)
	case stored.Token != token:
		return fmt.Errorf("Invalid token for userID %d: %s", userID, token)
	case time.Since(stored.CreatedAt) > tokenTTL:
		return fmt.Errorf("Token expired for userID %d: %s", userID, token)
	default:
		return nil
	}
}
// RefreshToken replaces the token stored for userID with a newly generated
// one and returns it. The token map lock is held for the whole operation,
// including random generation. On random-source failure an error is returned
// and the existing entry is left untouched.
func RefreshToken(userID int) (string, error) {
	tokenMux.Lock()
	defer tokenMux.Unlock()

	seed := make([]byte, 32)
	if _, err := rand.Read(seed); err != nil {
		return "", fmt.Errorf("Failed to generate random bytes: %w", err)
	}

	// Token = base64url(SHA-256(seed || decimal userID)).
	hasher := sha256.New()
	hasher.Write(seed)
	hasher.Write([]byte(fmt.Sprintf("%d", userID)))
	token := base64.URLEncoding.EncodeToString(hasher.Sum(nil))

	entry := &TokenInfo{
		Token:     token,
		CreatedAt: time.Now(),
		UserID:    userID,
	}
	tokenMap[userID] = entry

	postLog.Debug(fmt.Sprintf("[RefreshToken] Refreshed token for userID %d: %s", userID, token))
	return token, nil
}
// RemoveToken deletes the token stored for userID, if any, and logs the
// removal at debug level.
//
// The entry is looked up before it is deleted: reading tokenMap[userID]
// after the delete yields a nil *TokenInfo, and dereferencing it panics.
// When the user has no token the call is a no-op.
func RemoveToken(userID int) {
	tokenMux.Lock()
	defer tokenMux.Unlock()

	tokenInfo, exists := tokenMap[userID]
	if !exists || tokenInfo == nil {
		// Nothing stored (or a nil placeholder): make sure the key is gone
		// and skip the log line, which needs a real token.
		delete(tokenMap, userID)
		return
	}

	delete(tokenMap, userID)
	postLog.Debug(fmt.Sprintf("[RemoveToken] Removed token for userID %d: %s", userID, tokenInfo.Token))
}
// GetTokenInfo returns the TokenInfo currently stored for userID, or an error
// when the user has no token. The returned pointer is the stored entry itself,
// not a copy.
func GetTokenInfo(userID int) (*TokenInfo, error) {
	tokenMux.RLock()
	defer tokenMux.RUnlock()

	if entry, found := tokenMap[userID]; found {
		return entry, nil
	}
	return nil, fmt.Errorf("Token not found for userID %d", userID)
}
// ExtractUserIDFromToken scans tokenMap for an entry whose token equals token
// and returns the owning user ID. It does not check expiry. An error is
// returned when no stored token matches.
func ExtractUserIDFromToken(token string) (int, error) {
	tokenMux.RLock()
	defer tokenMux.RUnlock()

	for id, entry := range tokenMap {
		if entry.Token != token {
			continue
		}
		return id, nil
	}
	return 0, fmt.Errorf("Invalid token: %s", token)
}
// CleanupExpiredTokens removes every token older than tokenTTL and then drops
// all sessions that were joined with one of those tokens.
//
// The work is split into two phases so that tokenMux is never held together
// with the session locks:
//
//  1. Under tokenMux, expired tokens are deleted and remembered.
//  2. Under sessionMux and sessionTokenMux (taken in that order), sessions
//     bound to a remembered token are deleted from both session maps.
//
// Holding sessionMux in phase 2 is required because sessionMap is modified.
// Collecting the expired tokens first also means the session table is scanned
// once rather than once per expired token.
func CleanupExpiredTokens() {
	// Phase 1: expire tokens. The map is token -> owning user ID.
	expired := func() map[string]int {
		tokenMux.Lock()
		defer tokenMux.Unlock()

		removed := make(map[string]int)
		for userID, tokenInfo := range tokenMap {
			if time.Since(tokenInfo.CreatedAt) > tokenTTL {
				delete(tokenMap, userID)
				removed[tokenInfo.Token] = userID
				postLog.Debug(fmt.Sprintf("[CleanupExpiredTokens] Removed expired token for userID %d: %s", userID, tokenInfo.Token))
			}
		}
		return removed
	}()
	if len(expired) == 0 {
		return
	}

	// Phase 2: drop sessions that were bound to an expired token.
	sessionMux.Lock()
	defer sessionMux.Unlock()
	sessionTokenMux.Lock()
	defer sessionTokenMux.Unlock()

	for sessionID, sessionToken := range sessionTokenMap {
		userID, wasExpired := expired[sessionToken]
		if !wasExpired {
			continue
		}
		delete(sessionTokenMap, sessionID)
		delete(sessionMap, sessionID)
		postLog.Debug(fmt.Sprintf("[CleanupExpiredTokens] Removed expired session %s for userID %d", sessionID, userID))
	}
}
// CleanupExpiredSessions deletes every session whose ExpireAt lies in the
// past, together with its session-to-token mapping. Both the session lock and
// the session-token lock are held (in that order) for the whole sweep.
func CleanupExpiredSessions() {
	sessionMux.Lock()
	defer sessionMux.Unlock()
	sessionTokenMux.Lock()
	defer sessionTokenMux.Unlock()

	for id, s := range sessionMap {
		if !time.Now().After(s.ExpireAt) {
			// Still valid; keep it.
			continue
		}
		delete(sessionMap, id)
		delete(sessionTokenMap, id)
		postLog.Debug(fmt.Sprintf("[CleanupExpiredSessions] Removed expired session %s for user [%d]%s", id, s.UserID, s.Username))
	}
}
// HashPassword returns the lowercase hex encoding of the SHA-256 digest of
// password. The error result is always nil.
//
// NOTE(review): this is a single unsalted SHA-256 pass; a dedicated password
// hashing scheme would be stronger, but changing it here would invalidate any
// hashes already stored.
func HashPassword(password string) (string, error) {
	hasher := sha256.New()
	hasher.Write([]byte(password))
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest), nil
}
// VerifyPassword reports whether hashing password with HashPassword yields
// hashedPassword. If hashing fails the failure is logged and false is
// returned.
func VerifyPassword(password, hashedPassword string) bool {
	computed, err := HashPassword(password)
	if err == nil {
		return computed == hashedPassword
	}
	postLog.Error(fmt.Sprintf("[verifyPassword] Failed to hash password: %v", err))
	return false
}
// IsValidPassword reports whether password satisfies the complexity policy:
// at least 8 bytes long and containing at least one ASCII uppercase letter,
// one ASCII lowercase letter, one ASCII digit and one character from the
// special set !@#$%^&*()_+-=[]{}|;:,.<>? .
func IsValidPassword(password string) bool {
	if len(password) < 8 {
		return false
	}

	// One bit per required character class.
	const (
		sawUpper uint8 = 1 << iota
		sawLower
		sawDigit
		sawSpecial
		sawAll = sawUpper | sawLower | sawDigit | sawSpecial
	)
	const specialChars = "!@#$%^&*()_+-=[]{}|;:,.<>?"

	var seen uint8
	for _, r := range password {
		switch {
		case 'A' <= r && r <= 'Z':
			seen |= sawUpper
		case 'a' <= r && r <= 'z':
			seen |= sawLower
		case '0' <= r && r <= '9':
			seen |= sawDigit
		case strings.IndexRune(specialChars, r) >= 0:
			seen |= sawSpecial
		}
	}
	return seen == sawAll
}
// ValidateTimeStamp reports whether the X-Timestamp request header holds a
// Unix time in milliseconds that is within 3000 ms of the current time, in
// either direction. A missing or non-integer header is rejected. When debug
// is true the check is skipped and true is returned.
func ValidateTimeStamp(header http.Header, debug bool) bool {
	if debug {
		return true
	}

	raw := header.Get("X-Timestamp")
	if raw == "" {
		return false
	}

	sent, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return false
	}

	// Accept only when the drift is at most 3000 ms in both directions.
	drift := time.Now().UnixMilli() - sent
	return drift <= 3000 && -drift <= 3000
}
// GetUserIDFromToken resolves token to its owning user ID via
// ExtractUserIDFromToken, wrapping any lookup failure with extra context.
func GetUserIDFromToken(token string) (int, error) {
	id, lookupErr := ExtractUserIDFromToken(token)
	if lookupErr == nil {
		return id, nil
	}
	return 0, fmt.Errorf("Failed to extract userID from token: %w", lookupErr)
}
// DeleteTokenInfo removes the token entry for userID from tokenMap (a no-op
// when none exists) and logs the removal. It always returns nil.
func DeleteTokenInfo(userID int) error {
	tokenMux.Lock()
	defer tokenMux.Unlock()

	delete(tokenMap, userID)

	msg := fmt.Sprintf("[DeleteTokenInfo] Removed token for userID %d", userID)
	postLog.Debug(msg)
	return nil
}
// GetUsernameByID looks the user up through database.GetUserByID and returns
// its Username. The empty string is returned when the lookup reports an
// error.
func GetUsernameByID(userID int) string {
	record, err := database.GetUserByID(userID)
	if err == nil {
		return record.Username
	}
	return ""
}
// generateSessionID builds a session identifier of the form
// "session-<32 hex chars>-<unix nanoseconds>" from 16 random bytes.
//
// If the random source fails it falls back to
// "session-<unix nanoseconds>-<unix seconds>".
// NOTE(review): the fallback is derived only from the clock and is therefore
// guessable; the signature leaves no way to report the failure.
func generateSessionID() string {
	entropy := make([]byte, 16)
	if _, err := rand.Read(entropy); err != nil {
		nanos := time.Now().UnixNano()
		secs := fmt.Sprintf("%d", time.Now().Unix())
		return fmt.Sprintf("session-%d-%s", nanos, secs)
	}

	stamp := time.Now().UnixNano()
	return fmt.Sprintf("session-%x-%d", entropy, stamp)
}
// JoinSession creates a new session for the given user, valid for sessionTTL
// from now, and associates it with token. It always returns nil.
//
// Both session maps are written, so both locks are taken: sessionMux first,
// then sessionTokenMux — the same order CleanupExpiredSessions uses. Writing
// sessionTokenMap without sessionTokenMux would race with the readers that
// take that lock (GetSessionToken, GetSessionTokenMapSnapshot).
func JoinSession(userID int, userName string, token string) error {
	sessionMux.Lock()
	defer sessionMux.Unlock()
	sessionTokenMux.Lock()
	defer sessionTokenMux.Unlock()

	sessionID := generateSessionID()
	sessionMap[sessionID] = &Session{
		ID:       sessionID,
		UserID:   userID,
		Username: userName,
		ExpireAt: time.Now().Add(sessionTTL),
	}
	sessionTokenMap[sessionID] = token

	postLog.Debug(fmt.Sprintf("[JoinSession] User [%d]%s joined session %s with token %s", userID, userName, sessionID, token))
	return nil
}
// RemoveSession deletes the session identified by sessionID together with its
// session-to-token mapping, and then removes the token stored for the
// session's user.
//
// It returns an error when the session does not exist, or when the session
// exists but has no token mapping (the session itself is still removed in
// that case).
//
// Each shared map is accessed under its own lock: sessionMap under
// sessionMux, sessionTokenMap under sessionTokenMux, tokenMap under tokenMux.
// The token locks are held only briefly and never while acquiring another
// lock.
func RemoveSession(sessionID string) error {
	sessionMux.Lock()
	defer sessionMux.Unlock()

	session, exists := sessionMap[sessionID]
	if !exists {
		return fmt.Errorf("Session not found: %s", sessionID)
	}
	delete(sessionMap, sessionID)

	sessionTokenMux.Lock()
	_, hadToken := sessionTokenMap[sessionID]
	delete(sessionTokenMap, sessionID)
	sessionTokenMux.Unlock()
	if !hadToken {
		return fmt.Errorf("Token not found for session: %s", sessionID)
	}

	// Read under the token lock, then release it before DeleteTokenInfo,
	// which takes the write lock itself.
	tokenMux.RLock()
	tokenInfo, hasToken := tokenMap[session.UserID]
	tokenMux.RUnlock()
	if hasToken {
		_ = DeleteTokenInfo(session.UserID) // always returns nil
		postLog.Info(fmt.Sprintf("[RemoveSession] Removed session '%s': '[%d]%s'", sessionID, tokenInfo.UserID, GetUsernameByID(tokenInfo.UserID)))
	}
	return nil
}
// ListActiveSessions returns the sessions whose ExpireAt is still in the
// future. The result is nil when there are none, and its order is
// unspecified because it follows map iteration. The returned pointers are the
// stored Session values, not copies.
func ListActiveSessions() []*Session {
	sessionMux.RLock()
	defer sessionMux.RUnlock()

	var live []*Session
	for _, s := range sessionMap {
		if !s.ExpireAt.After(time.Now()) {
			continue
		}
		live = append(live, s)
	}
	return live
}
// ValidateTokenFromMap searches tokenMap for token and returns the owning
// user ID when the token exists and is not older than tokenTTL. It returns
// "token expired" for a matching but stale token and "invalid token" when no
// entry matches.
func ValidateTokenFromMap(token string) (int, error) {
	tokenMux.RLock()
	defer tokenMux.RUnlock()

	for id, entry := range tokenMap {
		if entry.Token != token {
			continue
		}
		if age := time.Since(entry.CreatedAt); age > tokenTTL {
			return 0, fmt.Errorf("token expired")
		}
		return id, nil
	}
	return 0, fmt.Errorf("invalid token")
}
// GetSessionTokenMap returns a pointer to sessionTokenMux. Despite its name
// it exposes the lock, not the map.
func GetSessionTokenMap() *sync.RWMutex {
	mux := &sessionTokenMux
	return mux
}
// GetSessionTokenMapSnapshot returns a copy of the session-ID-to-token map
// taken under the read lock. The result is never nil, and later changes to
// the live map do not affect it.
func GetSessionTokenMapSnapshot() map[string]string {
	sessionTokenMux.RLock()
	defer sessionTokenMux.RUnlock()

	copied := make(map[string]string, len(sessionTokenMap))
	for sessionID, token := range sessionTokenMap {
		copied[sessionID] = token
	}
	return copied
}
// GetSessionToken returns the token associated with sessionID and whether
// such an association exists.
func GetSessionToken(sessionID string) (string, bool) {
	sessionTokenMux.RLock()
	defer sessionTokenMux.RUnlock()

	if token, found := sessionTokenMap[sessionID]; found {
		return token, true
	}
	return "", false
}