refactor: reorganize codebase into modular packages

feat(global): add global package for shared variables and types
refactor(handlers): move handlers to dedicated package and update imports
refactor(session): extract session management to separate package
refactor(config): move config handling to dedicated package
refactor(router): update route handlers to use new package structure
refactor(main): simplify main.go by moving logic to packages
This commit is contained in:
2026-04-22 12:57:04 +08:00
parent df8df78bab
commit ddf91299e7
11 changed files with 667 additions and 663 deletions

View File

@@ -1,4 +1,4 @@
package main package config
import ( import (
"encoding/json" "encoding/json"
@@ -58,7 +58,7 @@ type FrpcConfig struct {
var globalConfig *Config var globalConfig *Config
func LoadConfig(configPath string) (*Config, error) { func LoadConfig(configPath string, getInitSystem func() string) (*Config, error) {
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err) return nil, fmt.Errorf("failed to read config file: %w", err)
@@ -78,7 +78,7 @@ func LoadConfig(configPath string) (*Config, error) {
} }
if config.FrpcPath == "" { if config.FrpcPath == "" {
if GetInitSystem() == "windows" { if getInitSystem() == "windows" {
config.FrpcPath = "frp_client/frpc.exe" config.FrpcPath = "frp_client/frpc.exe"
} else { } else {
config.FrpcPath = "/usr/bin/frpc" config.FrpcPath = "/usr/bin/frpc"
@@ -122,7 +122,7 @@ func SaveConfig(configPath string, config *Config) error {
return nil return nil
} }
func decodeFrpcConfig(configContent string) (FrpcConfig, error) { func DecodeFrpcConfig(configContent string) (FrpcConfig, error) {
var config FrpcConfig var config FrpcConfig
var rawMap map[string]interface{} var rawMap map[string]interface{}
@@ -173,7 +173,7 @@ func decodeFrpcConfig(configContent string) (FrpcConfig, error) {
return config, nil return config, nil
} }
func encodeFrpcConfig(config FrpcConfig) (string, error) { func EncodeFrpcConfig(config FrpcConfig) (string, error) {
var buf strings.Builder var buf strings.Builder
if len(config.Global) > 0 { if len(config.Global) > 0 {
@@ -230,12 +230,12 @@ func encodeFrpcConfig(config FrpcConfig) (string, error) {
return strings.TrimSuffix(buf.String(), "\n"), nil return strings.TrimSuffix(buf.String(), "\n"), nil
} }
func handleConfigFileCreate(configPath string, info InstanceInfo) error { func HandleConfigFileCreate(configPath string, info InstanceInfo, setKeyText func(configPath, key, section, value string) error) error {
config := FrpcConfig{ config := FrpcConfig{
Global: make(map[string]interface{}), Global: make(map[string]interface{}),
} }
result, err := encodeFrpcConfig(config) result, err := EncodeFrpcConfig(config)
if err != nil { if err != nil {
return fmt.Errorf("failed to encode empty config: %w", err) return fmt.Errorf("failed to encode empty config: %w", err)
} }
@@ -290,8 +290,8 @@ func handleConfigFileCreate(configPath string, info InstanceInfo) error {
return nil return nil
} }
func addFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) { func AddFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) {
config, err := decodeFrpcConfig(configContent) config, err := DecodeFrpcConfig(configContent)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse config: %w", err) return "", fmt.Errorf("failed to parse config: %w", err)
} }
@@ -306,7 +306,7 @@ func addFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) {
config.Proxies = append(config.Proxies, proxy) config.Proxies = append(config.Proxies, proxy)
result, err := encodeFrpcConfig(config) result, err := EncodeFrpcConfig(config)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to write config: %w", err) return "", fmt.Errorf("failed to write config: %w", err)
} }
@@ -314,8 +314,8 @@ func addFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) {
return result, nil return result, nil
} }
func removeFrpcProxy(configContent string, proxyName string) (string, error) { func RemoveFrpcProxy(configContent string, proxyName string) (string, error) {
config, err := decodeFrpcConfig(configContent) config, err := DecodeFrpcConfig(configContent)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse config: %w", err) return "", fmt.Errorf("failed to parse config: %w", err)
} }
@@ -336,7 +336,7 @@ func removeFrpcProxy(configContent string, proxyName string) (string, error) {
config.Proxies = newProxies config.Proxies = newProxies
result, err := encodeFrpcConfig(config) result, err := EncodeFrpcConfig(config)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to write config: %w", err) return "", fmt.Errorf("failed to write config: %w", err)
} }
@@ -344,8 +344,8 @@ func removeFrpcProxy(configContent string, proxyName string) (string, error) {
return result, nil return result, nil
} }
func modifyFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) { func ModifyFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) {
config, err := decodeFrpcConfig(configContent) config, err := DecodeFrpcConfig(configContent)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse config: %w", err) return "", fmt.Errorf("failed to parse config: %w", err)
} }
@@ -369,7 +369,7 @@ func modifyFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) {
return "", fmt.Errorf("proxy %s not found", info.Name) return "", fmt.Errorf("proxy %s not found", info.Name)
} }
result, err := encodeFrpcConfig(config) result, err := EncodeFrpcConfig(config)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to write config: %w", err) return "", fmt.Errorf("failed to write config: %w", err)
} }
@@ -377,13 +377,13 @@ func modifyFrpcProxy(configContent string, info FrpcProxyInfo) (string, error) {
return result, nil return result, nil
} }
func getKeyText(configPath, key, section string) (value string, err error) { func GetKeyText(configPath, key, section string, readFile func(string) ([]byte, error)) (value string, err error) {
configContent, err := os.ReadFile(configPath) configContent, err := readFile(configPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read config file: %w", err) return "", fmt.Errorf("failed to read config file: %w", err)
} }
config, err := decodeFrpcConfig(string(configContent)) config, err := DecodeFrpcConfig(string(configContent))
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse config: %w", err) return "", fmt.Errorf("failed to parse config: %w", err)
} }
@@ -399,13 +399,13 @@ func getKeyText(configPath, key, section string) (value string, err error) {
return value, nil return value, nil
} }
func setKeyText(configPath, key, section string, value string) error { func SetKeyText(configPath, key, section string, value string, readFile func(string) ([]byte, error), writeFile func(string, []byte, os.FileMode) error) error {
configContent, err := os.ReadFile(configPath) configContent, err := readFile(configPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to read config file: %w", err) return fmt.Errorf("failed to read config file: %w", err)
} }
config, err := decodeFrpcConfig(string(configContent)) config, err := DecodeFrpcConfig(string(configContent))
if err != nil { if err != nil {
return fmt.Errorf("failed to parse config: %w", err) return fmt.Errorf("failed to parse config: %w", err)
} }
@@ -420,12 +420,12 @@ func setKeyText(configPath, key, section string, value string) error {
config.Global[key] = value config.Global[key] = value
} }
result, err := encodeFrpcConfig(config) result, err := EncodeFrpcConfig(config)
if err != nil { if err != nil {
return fmt.Errorf("failed to write config: %w", err) return fmt.Errorf("failed to write config: %w", err)
} }
if err := os.WriteFile(configPath, []byte(result), 0644); err != nil { if err := writeFile(configPath, []byte(result), 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err) return fmt.Errorf("failed to write config file: %w", err)
} }

View File

@@ -1,4 +1,4 @@
package main package database
import ( import (
"database/sql" "database/sql"
@@ -6,19 +6,14 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"super-frpc/postLog"
"time" "time"
"super-frpc/global"
"super-frpc/postLog"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
var db *sql.DB
var logsDB *sql.DB
func GetLogsDatabase() *sql.DB {
return logsDB
}
type User struct { type User struct {
UserID int UserID int
Username string Username string
@@ -47,19 +42,23 @@ func InitDatabase(dbPath_data string, dbPath_log string) error {
} }
func CloseDatabase() error { func CloseDatabase() error {
if db != nil { if global.ConfigDB != nil {
if err := db.Close(); err != nil { if err := global.ConfigDB.Close(); err != nil {
return err return err
} }
} }
if logsDB != nil { if global.LogsDB != nil {
if err := logsDB.Close(); err != nil { if err := global.LogsDB.Close(); err != nil {
return err return err
} }
} }
return nil return nil
} }
func IsValidInput(input string) bool {
return isValidInput(input)
}
func isValidInput(input string) bool { func isValidInput(input string) bool {
invalidChars := []string{"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_", "EXEC", "EXECUTE", "DROP", "INSERT", "UPDATE", "DELETE", "SELECT"} invalidChars := []string{"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_", "EXEC", "EXECUTE", "DROP", "INSERT", "UPDATE", "DELETE", "SELECT"}
lowerInput := strings.ToLower(input) lowerInput := strings.ToLower(input)
@@ -71,7 +70,7 @@ func isValidInput(input string) bool {
return true return true
} }
func AddUser(username, passwd, userType string) (int, error) { // New user registration with specified type func AddUser(username, passwd, userType string, hashPassword func(string) (string, error), isValidPassword func(string) bool) (int, error) {
if !isValidInput(username) || !isValidInput(passwd) || !isValidInput(userType) { if !isValidInput(username) || !isValidInput(passwd) || !isValidInput(userType) {
return 0, errors.New("invalid input: contains illegal characters") return 0, errors.New("invalid input: contains illegal characters")
} }
@@ -85,7 +84,7 @@ func AddUser(username, passwd, userType string) (int, error) { // New user regis
return 0, fmt.Errorf("failed to hash password: %w", err) return 0, fmt.Errorf("failed to hash password: %w", err)
} }
result, err := db.Exec("INSERT INTO userLogin (username, passwd, type) VALUES (?, ?, ?)", result, err := global.ConfigDB.Exec("INSERT INTO userLogin (username, passwd, type) VALUES (?, ?, ?)",
username, hashedPasswd, userType) username, hashedPasswd, userType)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") { if strings.Contains(err.Error(), "UNIQUE constraint failed") {
@@ -100,7 +99,7 @@ func AddUser(username, passwd, userType string) (int, error) { // New user regis
} }
var count int var count int
err = db.QueryRow("SELECT COUNT(*) FROM userLogin WHERE userID = ?", lastID).Scan(&count) err = global.ConfigDB.QueryRow("SELECT COUNT(*) FROM userLogin WHERE userID = ?", lastID).Scan(&count)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to verify user insertion: %w", err) return 0, fmt.Errorf("failed to verify user insertion: %w", err)
} }
@@ -117,7 +116,7 @@ func RemoveUser(userID int) error {
return errors.New("invalid input: contains illegal characters") return errors.New("invalid input: contains illegal characters")
} }
result, err := db.Exec("DELETE FROM userLogin WHERE userID = ?", userID) result, err := global.ConfigDB.Exec("DELETE FROM userLogin WHERE userID = ?", userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete user: %w", err) return fmt.Errorf("failed to delete user: %w", err)
} }
@@ -140,7 +139,7 @@ func GetUserByUsername(username string) (*User, error) {
} }
var user User var user User
err := db.QueryRow("SELECT userID, username, passwd, type, createdAt FROM userLogin WHERE username = ?", username). err := global.ConfigDB.QueryRow("SELECT userID, username, passwd, type, createdAt FROM userLogin WHERE username = ?", username).
Scan(&user.UserID, &user.Username, &user.Passwd, &user.Type, &user.CreatedAt) Scan(&user.UserID, &user.Username, &user.Passwd, &user.Type, &user.CreatedAt)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@@ -154,7 +153,7 @@ func GetUserByUsername(username string) (*User, error) {
func GetUserByID(userID int) (*User, error) { func GetUserByID(userID int) (*User, error) {
var user User var user User
err := db.QueryRow("SELECT userID, username, passwd, type, createdAt FROM userLogin WHERE userID = ?", userID). err := global.ConfigDB.QueryRow("SELECT userID, username, passwd, type, createdAt FROM userLogin WHERE userID = ?", userID).
Scan(&user.UserID, &user.Username, &user.Passwd, &user.Type, &user.CreatedAt) Scan(&user.UserID, &user.Username, &user.Passwd, &user.Type, &user.CreatedAt)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@@ -166,7 +165,7 @@ func GetUserByID(userID int) (*User, error) {
return &user, nil return &user, nil
} }
func UpdateUserPassword(userID int, newPasswd string) error { func UpdateUserPassword(userID int, newPasswd string, hashPassword func(string) (string, error), isValidPassword func(string) bool) error {
if !isValidInput(newPasswd) { if !isValidInput(newPasswd) {
return errors.New("invalid input: contains illegal characters") return errors.New("invalid input: contains illegal characters")
} }
@@ -180,7 +179,7 @@ func UpdateUserPassword(userID int, newPasswd string) error {
return fmt.Errorf("failed to hash password: %w", err) return fmt.Errorf("failed to hash password: %w", err)
} }
result, err := db.Exec("UPDATE userLogin SET passwd = ? WHERE userID = ?", hashedPasswd, userID) result, err := global.ConfigDB.Exec("UPDATE userLogin SET passwd = ? WHERE userID = ?", hashedPasswd, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update password: %w", err) return fmt.Errorf("failed to update password: %w", err)
} }
@@ -208,7 +207,7 @@ func UpdateUserType(userID int, newType string) error {
return errors.New("invalid user type") return errors.New("invalid user type")
} }
result, err := db.Exec("UPDATE userLogin SET type = ? WHERE userID = ?", newType, userID) result, err := global.ConfigDB.Exec("UPDATE userLogin SET type = ? WHERE userID = ?", newType, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update user type: %w", err) return fmt.Errorf("failed to update user type: %w", err)
} }
@@ -227,7 +226,7 @@ func UpdateUserType(userID int, newType string) error {
func GetNextAvailableUserID() (int, error) { func GetNextAvailableUserID() (int, error) {
var maxID int var maxID int
err := db.QueryRow("SELECT COALESCE(MAX(userID), 0) FROM userLogin").Scan(&maxID) err := global.ConfigDB.QueryRow("SELECT COALESCE(MAX(userID), 0) FROM userLogin").Scan(&maxID)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get max userID: %w", err) return 0, fmt.Errorf("failed to get max userID: %w", err)
} }
@@ -235,7 +234,7 @@ func GetNextAvailableUserID() (int, error) {
} }
func GetAllUsers() ([]User, error) { func GetAllUsers() ([]User, error) {
rows, err := db.Query("SELECT userID, username, type FROM userLogin") rows, err := global.ConfigDB.Query("SELECT userID, username, type FROM userLogin")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query users: %w", err) return nil, fmt.Errorf("failed to query users: %w", err)
} }
@@ -259,12 +258,12 @@ func GetAllUsers() ([]User, error) {
func InitFrpcDatabase(dbPath string) error { func InitFrpcDatabase(dbPath string) error {
var err error var err error
frpcDB, err = sql.Open("sqlite", dbPath) global.FrpcDB, err = sql.Open("sqlite", dbPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open frpc database: %w", err) return fmt.Errorf("failed to open frpc database: %w", err)
} }
if err = frpcDB.Ping(); err != nil { if err = global.FrpcDB.Ping(); err != nil {
return fmt.Errorf("failed to ping frpc database: %w", err) return fmt.Errorf("failed to ping frpc database: %w", err)
} }
@@ -281,7 +280,7 @@ func InitFrpcDatabase(dbPath string) error {
UNIQUE(userID, name) UNIQUE(userID, name)
); );
` `
_, err = frpcDB.Exec(createTableSQL) _, err = global.FrpcDB.Exec(createTableSQL)
if err != nil { if err != nil {
return fmt.Errorf("failed to create frpcInstances table: %w", err) return fmt.Errorf("failed to create frpcInstances table: %w", err)
} }
@@ -291,12 +290,12 @@ func InitFrpcDatabase(dbPath string) error {
func InitUserDatabase(dbPath string) error { func InitUserDatabase(dbPath string) error {
var err error var err error
db, err = sql.Open("sqlite", dbPath) global.ConfigDB, err = sql.Open("sqlite", dbPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return fmt.Errorf("failed to open database: %w", err)
} }
if err = db.Ping(); err != nil { if err = global.ConfigDB.Ping(); err != nil {
return fmt.Errorf("failed to ping database: %w", err) return fmt.Errorf("failed to ping database: %w", err)
} }
@@ -309,7 +308,7 @@ func InitUserDatabase(dbPath string) error {
createdAt TEXT NOT NULL DEFAULT (datetime('now')) createdAt TEXT NOT NULL DEFAULT (datetime('now'))
); );
` `
_, err = db.Exec(createTableSQL) _, err = global.ConfigDB.Exec(createTableSQL)
if err != nil { if err != nil {
return fmt.Errorf("failed to create table: %w", err) return fmt.Errorf("failed to create table: %w", err)
} }
@@ -317,8 +316,8 @@ func InitUserDatabase(dbPath string) error {
return nil return nil
} }
func DBListUsers() ([]User, error) { // List all users func DBListUsers() ([]User, error) {
rows, err := db.Query("SELECT userID, username, type, createdAt FROM userLogin") rows, err := global.ConfigDB.Query("SELECT userID, username, type, createdAt FROM userLogin")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to query users: %w", err) return nil, fmt.Errorf("failed to query users: %w", err)
} }
@@ -339,9 +338,9 @@ func DBListUsers() ([]User, error) { // List all users
return users, nil return users, nil
} }
func DBQuerySpecificUser(userID int) (User, error) { // Query user by ID func DBQuerySpecificUser(userID int) (User, error) {
var user User var user User
err := db.QueryRow("SELECT userID, username, type, createdAt FROM userLogin WHERE userID = ?", userID).Scan(&user.UserID, &user.Username, &user.Type, &user.CreatedAt) err := global.ConfigDB.QueryRow("SELECT userID, username, type, createdAt FROM userLogin WHERE userID = ?", userID).Scan(&user.UserID, &user.Username, &user.Type, &user.CreatedAt)
if err != nil { if err != nil {
return user, fmt.Errorf("failed to query user: %w", err) return user, fmt.Errorf("failed to query user: %w", err)
} }
@@ -351,9 +350,9 @@ func DBQuerySpecificUser(userID int) (User, error) { // Query user by ID
func DBUpdateUser(userID int, username, passwd string) error { func DBUpdateUser(userID int, username, passwd string) error {
var err error var err error
if passwd == "" { if passwd == "" {
_, err = db.Exec("UPDATE userLogin SET username = ? WHERE userID = ?", username, userID) _, err = global.ConfigDB.Exec("UPDATE userLogin SET username = ? WHERE userID = ?", username, userID)
} else { } else {
_, err = db.Exec("UPDATE userLogin SET username = ?, passwd = ? WHERE userID = ?", username, passwd, userID) _, err = global.ConfigDB.Exec("UPDATE userLogin SET username = ?, passwd = ? WHERE userID = ?", username, passwd, userID)
} }
if err != nil { if err != nil {
return fmt.Errorf("failed to update user: %w", err) return fmt.Errorf("failed to update user: %w", err)
@@ -362,7 +361,7 @@ func DBUpdateUser(userID int, username, passwd string) error {
} }
func DBUpdateUserType(userID int, newType string) error { func DBUpdateUserType(userID int, newType string) error {
_, err := db.Exec("UPDATE userLogin SET type = ? WHERE userID = ?", newType, userID) _, err := global.ConfigDB.Exec("UPDATE userLogin SET type = ? WHERE userID = ?", newType, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update user type: %w", err) return fmt.Errorf("failed to update user type: %w", err)
} }
@@ -370,7 +369,7 @@ func DBUpdateUserType(userID int, newType string) error {
} }
func DBAddFrpcInstance(instance FrpcInstance) error { func DBAddFrpcInstance(instance FrpcInstance) error {
_, err := frpcDB.Exec("INSERT INTO frpcInstances (userID, name, bootAtStart, runUser, configPath, createdAt, watchdog) VALUES (?, ?, ?, ?, ?, ?, ?)", _, err := global.FrpcDB.Exec("INSERT INTO frpcInstances (userID, name, bootAtStart, runUser, configPath, createdAt, watchdog) VALUES (?, ?, ?, ?, ?, ?, ?)",
instance.UserID, instance.Name, instance.BootAtStart, instance.RunUser, instance.ConfigPath, time.Now().Format(time.RFC3339), instance.Watchdog) instance.UserID, instance.Name, instance.BootAtStart, instance.RunUser, instance.ConfigPath, time.Now().Format(time.RFC3339), instance.Watchdog)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert frpc instance: %w", err) return fmt.Errorf("failed to insert frpc instance: %w", err)
@@ -381,7 +380,7 @@ func DBAddFrpcInstance(instance FrpcInstance) error {
func DBQueryFrpcInstanceByID(instanceID int) (FrpcInstance, error) { func DBQueryFrpcInstanceByID(instanceID int) (FrpcInstance, error) {
var instance FrpcInstance var instance FrpcInstance
var createdAtStr string var createdAtStr string
err := frpcDB.QueryRow("SELECT id, userID, name, bootAtStart, runUser, configPath, createdAt, watchdog FROM frpcInstances WHERE id = ?", instanceID).Scan( err := global.FrpcDB.QueryRow("SELECT id, userID, name, bootAtStart, runUser, configPath, createdAt, watchdog FROM frpcInstances WHERE id = ?", instanceID).Scan(
&instance.ID, &instance.UserID, &instance.Name, &instance.BootAtStart, &instance.RunUser, &instance.ConfigPath, &createdAtStr, &instance.Watchdog) &instance.ID, &instance.UserID, &instance.Name, &instance.BootAtStart, &instance.RunUser, &instance.ConfigPath, &createdAtStr, &instance.Watchdog)
if err != nil { if err != nil {
return instance, fmt.Errorf("failed to query frpc instance: %w", err) return instance, fmt.Errorf("failed to query frpc instance: %w", err)
@@ -393,7 +392,7 @@ func DBQueryFrpcInstanceByID(instanceID int) (FrpcInstance, error) {
func DBQueryFrpcInstance(userID int, instanceName string) (FrpcInstance, error) { func DBQueryFrpcInstance(userID int, instanceName string) (FrpcInstance, error) {
var instance FrpcInstance var instance FrpcInstance
var createdAtStr string var createdAtStr string
err := frpcDB.QueryRow("SELECT id, userID, name, bootAtStart, runUser, configPath, createdAt, watchdog FROM frpcInstances WHERE userID = ? AND name = ?", userID, instanceName).Scan( err := global.FrpcDB.QueryRow("SELECT id, userID, name, bootAtStart, runUser, configPath, createdAt, watchdog FROM frpcInstances WHERE userID = ? AND name = ?", userID, instanceName).Scan(
&instance.ID, &instance.UserID, &instance.Name, &instance.BootAtStart, &instance.RunUser, &instance.ConfigPath, &createdAtStr, &instance.Watchdog) &instance.ID, &instance.UserID, &instance.Name, &instance.BootAtStart, &instance.RunUser, &instance.ConfigPath, &createdAtStr, &instance.Watchdog)
if err != nil { if err != nil {
return instance, fmt.Errorf("failed to query frpc instance: %w", err) return instance, fmt.Errorf("failed to query frpc instance: %w", err)
@@ -403,7 +402,7 @@ func DBQueryFrpcInstance(userID int, instanceName string) (FrpcInstance, error)
} }
func DBRemoveFrpcInstanceByID(instanceID int) error { func DBRemoveFrpcInstanceByID(instanceID int) error {
_, err := frpcDB.Exec("DELETE FROM frpcInstances WHERE id = ?", instanceID) _, err := global.FrpcDB.Exec("DELETE FROM frpcInstances WHERE id = ?", instanceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete frpc instance: %w", err) return fmt.Errorf("failed to delete frpc instance: %w", err)
} }
@@ -411,7 +410,7 @@ func DBRemoveFrpcInstanceByID(instanceID int) error {
} }
func DBUpdateFrpcInstance(instance FrpcInstance) error { func DBUpdateFrpcInstance(instance FrpcInstance) error {
_, err := frpcDB.Exec("UPDATE frpcInstances SET bootAtStart = ?, runUser = ?, configPath = ?, watchdog = ? WHERE id = ?", _, err := global.FrpcDB.Exec("UPDATE frpcInstances SET bootAtStart = ?, runUser = ?, configPath = ?, watchdog = ? WHERE id = ?",
instance.BootAtStart, instance.RunUser, instance.ConfigPath, instance.Watchdog, instance.ID) instance.BootAtStart, instance.RunUser, instance.ConfigPath, instance.Watchdog, instance.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update frpc instance: %w", err) return fmt.Errorf("failed to update frpc instance: %w", err)
@@ -420,7 +419,7 @@ func DBUpdateFrpcInstance(instance FrpcInstance) error {
} }
func DBListFrpcInstances() ([]FrpcInstance, error) { func DBListFrpcInstances() ([]FrpcInstance, error) {
rows, err := frpcDB.Query(` rows, err := global.FrpcDB.Query(`
SELECT fi.id, fi.userID, fi.name, fi.bootAtStart, fi.runUser, fi.configPath, fi.createdAt, fi.watchdog, u.username SELECT fi.id, fi.userID, fi.name, fi.bootAtStart, fi.runUser, fi.configPath, fi.createdAt, fi.watchdog, u.username
FROM frpcInstances fi FROM frpcInstances fi
JOIN userLogin u ON fi.userID = u.userID JOIN userLogin u ON fi.userID = u.userID
@@ -448,6 +447,17 @@ func DBListFrpcInstances() ([]FrpcInstance, error) {
return instances, nil return instances, nil
} }
func DBQueryUserInstances(userID int) (*sql.Rows, error) {
rows, err := global.FrpcDB.Query(`
SELECT id, userID, name, bootAtStart, runUser, configPath, createdAt, watchdog
FROM frpcInstances WHERE userID = ?
`, userID)
if err != nil {
return nil, err
}
return rows, nil
}
func GetServiceNameByInstanceID(instanceID int) (string, error) { func GetServiceNameByInstanceID(instanceID int) (string, error) {
instance, err := DBQueryFrpcInstanceByID(instanceID) instance, err := DBQueryFrpcInstanceByID(instanceID)
if err != nil { if err != nil {

40
global/global.go Normal file
View File

@@ -0,0 +1,40 @@
package global
import (
"database/sql"
)
type SoftwareInfo struct {
Name string
Version string
Developer string
BuildVer int16
Description string
BuildType string
}
type StatusFlags struct {
Debug bool
Online bool
WatchdogConnected bool
}
var (
ConfigDB *sql.DB
FrpcDB *sql.DB
LogsDB *sql.DB
)
var Software = SoftwareInfo{
Name: "Super-frpc",
Version: "0.0.1",
Developer: "Madobi Nanami",
BuildVer: 1,
BuildType: "debug",
}
var Is = StatusFlags{
Debug: false,
Online: false,
WatchdogConnected: false,
}

View File

@@ -1,12 +1,16 @@
package main package handlers
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
"super-frpc/postLog"
"time" "time"
"super-frpc/database"
"super-frpc/global"
"super-frpc/postLog"
"super-frpc/session"
) )
type Response struct { type Response struct {
@@ -54,21 +58,21 @@ func Auth(w http.ResponseWriter, r *http.Request, targetMethod string, allowedUs
return 0, fmt.Errorf("Method not allowed: %s", targetMethod) return 0, fmt.Errorf("Method not allowed: %s", targetMethod)
} }
if !is.debug && !ValidateTimeStamp(r.Header) { if !global.Is.Debug && !session.ValidateTimeStamp(r.Header, global.Is.Debug) {
return 0, fmt.Errorf("Invalid or missing X-Timestamp in header") return 0, fmt.Errorf("Invalid or missing X-Timestamp in header")
} }
userID, err := extractUserIDFromToken(r.Header.Get("X-Token")) userID, err := session.ExtractUserIDFromToken(r.Header.Get("X-Token"))
if err != nil { if err != nil {
return 0, fmt.Errorf("Invalid token format: %w", err) return 0, fmt.Errorf("Invalid token format: %w", err)
} }
if err := ValidateToken(userID, r.Header.Get("X-Token")); err != nil { if err := session.ValidateToken(userID, r.Header.Get("X-Token")); err != nil {
return 0, fmt.Errorf("Token validation failed: %w", err) return 0, fmt.Errorf("Token validation failed: %w", err)
} }
if len(allowedUserLevels) > 0 { if len(allowedUserLevels) > 0 {
currentUser, err := DBQuerySpecificUser(userID) currentUser, err := database.DBQuerySpecificUser(userID)
if err != nil { if err != nil {
return 0, fmt.Errorf("Failed to query user: %w", err) return 0, fmt.Errorf("Failed to query user: %w", err)
} }
@@ -88,7 +92,7 @@ func Auth(w http.ResponseWriter, r *http.Request, targetMethod string, allowedUs
} }
func GetUserType(userID int) (string, error) { func GetUserType(userID int) (string, error) {
user, err := GetUserByID(userID) user, err := database.GetUserByID(userID)
if err != nil { if err != nil {
return "", fmt.Errorf("Failed to get user type: %w", err) return "", fmt.Errorf("Failed to get user type: %w", err)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
package main package instance
import ( import (
"encoding/json" "encoding/json"
@@ -7,15 +7,19 @@ import (
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
"super-frpc/config"
"super-frpc/database"
"super-frpc/handlers"
"super-frpc/postLog" "super-frpc/postLog"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
) )
func CreateProxyHandler(w http.ResponseWriter, r *http.Request) { func CreateProxyHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodPost, "superuser", "admin") userID, err := handlers.Auth(w, r, http.MethodPost, "superuser", "admin")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[CreateProxyHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[CreateProxyHandler] Auth failed: %v", err))
return return
} }
@@ -23,7 +27,7 @@ func CreateProxyHandler(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to read request body: %v", err)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to read request body: %v", err))
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
return return
} }
defer r.Body.Close() defer r.Body.Close()
@@ -31,25 +35,25 @@ func CreateProxyHandler(w http.ResponseWriter, r *http.Request) {
var reqMap map[string]interface{} var reqMap map[string]interface{}
if err := json.Unmarshal(body, &reqMap); err != nil { if err := json.Unmarshal(body, &reqMap); err != nil {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to unmarshal request body: %v", err)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to unmarshal request body: %v", err))
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
return return
} }
instanceID := getStringFromMap(reqMap, "instanceID") instanceID := getStringFromMap(reqMap, "instanceID")
if instanceID == "" { if instanceID == "" {
postLog.Error("[CreateProxyHandler] instanceID is required") postLog.Error("[CreateProxyHandler] instanceID is required")
SendErrorResponse(w, http.StatusBadRequest, "instanceID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "instanceID is required")
return return
} }
proxyInfoMap, ok := reqMap["proxyInfo"].(map[string]interface{}) proxyInfoMap, ok := reqMap["proxyInfo"].(map[string]interface{})
if !ok { if !ok {
postLog.Error("[CreateProxyHandler] Invalid proxyInfo format") postLog.Error("[CreateProxyHandler] Invalid proxyInfo format")
SendErrorResponse(w, http.StatusBadRequest, "Invalid proxyInfo format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid proxyInfo format")
return return
} }
proxyInfo := FrpcProxyInfo{ proxyInfo := config.FrpcProxyInfo{
Name: getStringFromMap(proxyInfoMap, "name"), Name: getStringFromMap(proxyInfoMap, "name"),
Type: getStringFromMap(proxyInfoMap, "type"), Type: getStringFromMap(proxyInfoMap, "type"),
LocalIP: getStringFromMap(proxyInfoMap, "localIP"), LocalIP: getStringFromMap(proxyInfoMap, "localIP"),
@@ -60,46 +64,46 @@ func CreateProxyHandler(w http.ResponseWriter, r *http.Request) {
if proxyInfo.Name == "" || proxyInfo.Type == "" || proxyInfo.LocalIP == "" || if proxyInfo.Name == "" || proxyInfo.Type == "" || proxyInfo.LocalIP == "" ||
proxyInfo.LocalPort == 0 || proxyInfo.RemotePort == 0 { proxyInfo.LocalPort == 0 || proxyInfo.RemotePort == 0 {
postLog.Error("[CreateProxyHandler] Missing required fields in proxyInfo") postLog.Error("[CreateProxyHandler] Missing required fields in proxyInfo")
SendErrorResponse(w, http.StatusBadRequest, "Missing required fields in proxyInfo") handlers.SendErrorResponse(w, http.StatusBadRequest, "Missing required fields in proxyInfo")
return return
} }
var instance FrpcInstance var instance database.FrpcInstance
instanceIDInt, _ := strconv.Atoi(instanceID) instanceIDInt, _ := strconv.Atoi(instanceID)
instance, err = DBQueryFrpcInstanceByID(instanceIDInt) instance, err = database.DBQueryFrpcInstanceByID(instanceIDInt)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to query instance: %v", err)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to query instance: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance")
return return
} }
if instance.UserID != userID { if instance.UserID != userID {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Instance not found for user %d", userID)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Instance not found for user %d", userID))
SendErrorResponse(w, http.StatusNotFound, "Instance not found") handlers.SendErrorResponse(w, http.StatusNotFound, "Instance not found")
return return
} }
configContent, err := os.ReadFile(instance.ConfigPath) configContent, err := os.ReadFile(instance.ConfigPath)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to read config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to read config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file")
return return
} }
updatedContent, err := addFrpcProxy(string(configContent), proxyInfo) updatedContent, err := config.AddFrpcProxy(string(configContent), proxyInfo)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to add proxy: %v", err)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to add proxy: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to add proxy") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to add proxy")
return return
} }
if err := os.WriteFile(instance.ConfigPath, []byte(updatedContent), 0644); err != nil { if err := os.WriteFile(instance.ConfigPath, []byte(updatedContent), 0644); err != nil {
postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to write config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[CreateProxyHandler] Failed to write config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to write config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to write config file")
return return
} }
SendSuccessResponse(w, "Proxy created successfully", map[string]interface{}{ handlers.SendSuccessResponse(w, "Proxy created successfully", map[string]interface{}{
"instanceID": instance.ID, "instanceID": instance.ID,
"configPath": instance.ConfigPath, "configPath": instance.ConfigPath,
"proxyName": proxyInfo.Name, "proxyName": proxyInfo.Name,
@@ -108,9 +112,9 @@ func CreateProxyHandler(w http.ResponseWriter, r *http.Request) {
} }
func ModifyProxyHandler(w http.ResponseWriter, r *http.Request) { func ModifyProxyHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodPost, "superuser", "admin") userID, err := handlers.Auth(w, r, http.MethodPost, "superuser", "admin")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[ModifyProxyHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[ModifyProxyHandler] Auth failed: %v", err))
return return
} }
@@ -118,7 +122,7 @@ func ModifyProxyHandler(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to read request body: %v", err)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to read request body: %v", err))
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
return return
} }
defer r.Body.Close() defer r.Body.Close()
@@ -126,25 +130,25 @@ func ModifyProxyHandler(w http.ResponseWriter, r *http.Request) {
var reqMap map[string]interface{} var reqMap map[string]interface{}
if err := json.Unmarshal(body, &reqMap); err != nil { if err := json.Unmarshal(body, &reqMap); err != nil {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to unmarshal request body: %v", err)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to unmarshal request body: %v", err))
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
return return
} }
instanceID := getStringFromMap(reqMap, "instanceID") instanceID := getStringFromMap(reqMap, "instanceID")
if instanceID == "" { if instanceID == "" {
postLog.Error("[ModifyProxyHandler] instanceID is required") postLog.Error("[ModifyProxyHandler] instanceID is required")
SendErrorResponse(w, http.StatusBadRequest, "instanceID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "instanceID is required")
return return
} }
proxyInfoMap, ok := reqMap["proxyInfo"].(map[string]interface{}) proxyInfoMap, ok := reqMap["proxyInfo"].(map[string]interface{})
if !ok { if !ok {
postLog.Error("[ModifyProxyHandler] Invalid proxyInfo format") postLog.Error("[ModifyProxyHandler] Invalid proxyInfo format")
SendErrorResponse(w, http.StatusBadRequest, "Invalid proxyInfo format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid proxyInfo format")
return return
} }
proxyInfo := FrpcProxyInfo{ proxyInfo := config.FrpcProxyInfo{
Name: getStringFromMap(proxyInfoMap, "name"), Name: getStringFromMap(proxyInfoMap, "name"),
Type: getStringFromMap(proxyInfoMap, "type"), Type: getStringFromMap(proxyInfoMap, "type"),
LocalIP: getStringFromMap(proxyInfoMap, "localIP"), LocalIP: getStringFromMap(proxyInfoMap, "localIP"),
@@ -155,46 +159,46 @@ func ModifyProxyHandler(w http.ResponseWriter, r *http.Request) {
if proxyInfo.Name == "" || proxyInfo.Type == "" || proxyInfo.LocalIP == "" || if proxyInfo.Name == "" || proxyInfo.Type == "" || proxyInfo.LocalIP == "" ||
proxyInfo.LocalPort == 0 || proxyInfo.RemotePort == 0 { proxyInfo.LocalPort == 0 || proxyInfo.RemotePort == 0 {
postLog.Error("[ModifyProxyHandler] Missing required fields in proxyInfo") postLog.Error("[ModifyProxyHandler] Missing required fields in proxyInfo")
SendErrorResponse(w, http.StatusBadRequest, "Missing required fields in proxyInfo") handlers.SendErrorResponse(w, http.StatusBadRequest, "Missing required fields in proxyInfo")
return return
} }
var instance FrpcInstance var instance database.FrpcInstance
instanceIDInt, _ := strconv.Atoi(instanceID) instanceIDInt, _ := strconv.Atoi(instanceID)
instance, err = DBQueryFrpcInstanceByID(instanceIDInt) instance, err = database.DBQueryFrpcInstanceByID(instanceIDInt)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to query instance: %v", err)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to query instance: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance")
return return
} }
if instance.UserID != userID { if instance.UserID != userID {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Instance not found for user %d", userID)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Instance not found for user %d", userID))
SendErrorResponse(w, http.StatusNotFound, "Instance not found") handlers.SendErrorResponse(w, http.StatusNotFound, "Instance not found")
return return
} }
configContent, err := os.ReadFile(instance.ConfigPath) configContent, err := os.ReadFile(instance.ConfigPath)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to read config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to read config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file")
return return
} }
updatedContent, err := modifyFrpcProxy(string(configContent), proxyInfo) updatedContent, err := config.ModifyFrpcProxy(string(configContent), proxyInfo)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to modify proxy: %v", err)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to modify proxy: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to modify proxy") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to modify proxy")
return return
} }
if err := os.WriteFile(instance.ConfigPath, []byte(updatedContent), 0644); err != nil { if err := os.WriteFile(instance.ConfigPath, []byte(updatedContent), 0644); err != nil {
postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to write config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[ModifyProxyHandler] Failed to write config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to write config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to write config file")
return return
} }
SendSuccessResponse(w, "Proxy modified successfully", map[string]interface{}{ handlers.SendSuccessResponse(w, "Proxy modified successfully", map[string]interface{}{
"instanceID": instance.ID, "instanceID": instance.ID,
"configPath": instance.ConfigPath, "configPath": instance.ConfigPath,
"proxyName": proxyInfo.Name, "proxyName": proxyInfo.Name,
@@ -203,9 +207,9 @@ func ModifyProxyHandler(w http.ResponseWriter, r *http.Request) {
} }
func DeleteProxyHandler(w http.ResponseWriter, r *http.Request) { func DeleteProxyHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodPost, "superuser", "admin") userID, err := handlers.Auth(w, r, http.MethodPost, "superuser", "admin")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[DeleteProxyHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[DeleteProxyHandler] Auth failed: %v", err))
return return
} }
@@ -213,7 +217,7 @@ func DeleteProxyHandler(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to read request body: %v", err)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to read request body: %v", err))
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
return return
} }
defer r.Body.Close() defer r.Body.Close()
@@ -221,60 +225,60 @@ func DeleteProxyHandler(w http.ResponseWriter, r *http.Request) {
var reqMap map[string]interface{} var reqMap map[string]interface{}
if err := json.Unmarshal(body, &reqMap); err != nil { if err := json.Unmarshal(body, &reqMap); err != nil {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to unmarshal request body: %v", err)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to unmarshal request body: %v", err))
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
return return
} }
instanceID := getStringFromMap(reqMap, "instanceID") instanceID := getStringFromMap(reqMap, "instanceID")
if instanceID == "" { if instanceID == "" {
postLog.Error("[DeleteProxyHandler] instanceID is required") postLog.Error("[DeleteProxyHandler] instanceID is required")
SendErrorResponse(w, http.StatusBadRequest, "instanceID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "instanceID is required")
return return
} }
proxyName := getStringFromMap(reqMap, "proxyName") proxyName := getStringFromMap(reqMap, "proxyName")
if proxyName == "" { if proxyName == "" {
postLog.Error("[DeleteProxyHandler] proxyName is required") postLog.Error("[DeleteProxyHandler] proxyName is required")
SendErrorResponse(w, http.StatusBadRequest, "proxyName is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "proxyName is required")
return return
} }
var instance FrpcInstance var instance database.FrpcInstance
instanceIDInt, _ := strconv.Atoi(instanceID) instanceIDInt, _ := strconv.Atoi(instanceID)
instance, err = DBQueryFrpcInstanceByID(instanceIDInt) instance, err = database.DBQueryFrpcInstanceByID(instanceIDInt)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to query instance: %v", err)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to query instance: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance")
return return
} }
if instance.UserID != userID { if instance.UserID != userID {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Instance not found for user %d", userID)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Instance not found for user %d", userID))
SendErrorResponse(w, http.StatusNotFound, "Instance not found") handlers.SendErrorResponse(w, http.StatusNotFound, "Instance not found")
return return
} }
configContent, err := os.ReadFile(instance.ConfigPath) configContent, err := os.ReadFile(instance.ConfigPath)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to read config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to read config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file")
return return
} }
updatedContent, err := removeFrpcProxy(string(configContent), proxyName) updatedContent, err := config.RemoveFrpcProxy(string(configContent), proxyName)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to remove proxy: %v", err)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to remove proxy: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to remove proxy") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to remove proxy")
return return
} }
if err := os.WriteFile(instance.ConfigPath, []byte(updatedContent), 0644); err != nil { if err := os.WriteFile(instance.ConfigPath, []byte(updatedContent), 0644); err != nil {
postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to write config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[DeleteProxyHandler] Failed to write config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to write config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to write config file")
return return
} }
SendSuccessResponse(w, "Proxy deleted successfully", map[string]interface{}{ handlers.SendSuccessResponse(w, "Proxy deleted successfully", map[string]interface{}{
"instanceID": instance.ID, "instanceID": instance.ID,
"configPath": instance.ConfigPath, "configPath": instance.ConfigPath,
"proxyName": proxyName, "proxyName": proxyName,
@@ -283,9 +287,9 @@ func DeleteProxyHandler(w http.ResponseWriter, r *http.Request) {
} }
func ListProxiesHandler(w http.ResponseWriter, r *http.Request) { func ListProxiesHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodGet) userID, err := handlers.Auth(w, r, http.MethodGet)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[ListProxiesHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[ListProxiesHandler] Auth failed: %v", err))
return return
} }
@@ -294,41 +298,41 @@ func ListProxiesHandler(w http.ResponseWriter, r *http.Request) {
instanceID := queryParams.Get("instanceID") instanceID := queryParams.Get("instanceID")
if instanceID == "" { if instanceID == "" {
postLog.Error("[ListProxiesHandler] instanceID is required") postLog.Error("[ListProxiesHandler] instanceID is required")
SendErrorResponse(w, http.StatusBadRequest, "instanceID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "instanceID is required")
return return
} }
var instance FrpcInstance var instance database.FrpcInstance
instanceIDInt, _ := strconv.Atoi(instanceID) instanceIDInt, _ := strconv.Atoi(instanceID)
instance, err = DBQueryFrpcInstanceByID(instanceIDInt) instance, err = database.DBQueryFrpcInstanceByID(instanceIDInt)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[ListProxiesHandler] Failed to query instance: %v", err)) postLog.Error(fmt.Sprintf("[ListProxiesHandler] Failed to query instance: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to query instance")
return return
} }
if instance.UserID != userID { if instance.UserID != userID {
postLog.Error(fmt.Sprintf("[ListProxiesHandler] Instance not found for user %d", userID)) postLog.Error(fmt.Sprintf("[ListProxiesHandler] Instance not found for user %d", userID))
SendErrorResponse(w, http.StatusNotFound, "Instance not found") handlers.SendErrorResponse(w, http.StatusNotFound, "Instance not found")
return return
} }
configContent, err := os.ReadFile(instance.ConfigPath) configContent, err := os.ReadFile(instance.ConfigPath)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[ListProxiesHandler] Failed to read config file %s: %v", instance.ConfigPath, err)) postLog.Error(fmt.Sprintf("[ListProxiesHandler] Failed to read config file %s: %v", instance.ConfigPath, err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to read config file")
return return
} }
var config FrpcConfig var cfg config.FrpcConfig
if _, err := toml.Decode(string(configContent), &config); err != nil { if _, err := toml.Decode(string(configContent), &cfg); err != nil {
postLog.Error(fmt.Sprintf("[ListProxiesHandler] Failed to parse config: %v", err)) postLog.Error(fmt.Sprintf("[ListProxiesHandler] Failed to parse config: %v", err))
SendErrorResponse(w, http.StatusInternalServerError, "Failed to parse config file") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to parse config file")
return return
} }
proxyList := make([]map[string]interface{}, len(config.Proxies)) proxyList := make([]map[string]interface{}, len(cfg.Proxies))
for i, proxy := range config.Proxies { for i, proxy := range cfg.Proxies {
proxyData := map[string]interface{}{ proxyData := map[string]interface{}{
"name": proxy["name"], "name": proxy["name"],
"type": proxy["type"], "type": proxy["type"],
@@ -339,7 +343,7 @@ func ListProxiesHandler(w http.ResponseWriter, r *http.Request) {
proxyList[i] = proxyData proxyList[i] = proxyData
} }
SendSuccessResponse(w, "Proxies listed successfully", map[string]interface{}{ handlers.SendSuccessResponse(w, "Proxies listed successfully", map[string]interface{}{
"instanceID": instance.ID, "instanceID": instance.ID,
"proxyCount": len(proxyList), "proxyCount": len(proxyList),
"proxies": proxyList, "proxies": proxyList,

102
main.go
View File

@@ -6,49 +6,21 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"super-frpc/config"
"super-frpc/database"
"super-frpc/frpLogger" "super-frpc/frpLogger"
"super-frpc/global"
"super-frpc/handlers"
"super-frpc/postLog" "super-frpc/postLog"
"super-frpc/session"
"super-frpc/service"
"super-frpc/watchdog" "super-frpc/watchdog"
"syscall" "syscall"
"time" "time"
) )
type SoftwareInfo struct {
Name string
Version string
Developer string
BuildVer int16
Description string
BuildType string
}
var softwareInfo SoftwareInfo = SoftwareInfo{
Name: "Super-frpc",
Version: "0.0.1",
Developer: "Madobi Nanami",
BuildVer: 1,
BuildType: "debug",
}
type StatusInfo struct {
ServerStatus string
WatchdogStatus string
}
type Is struct {
debug bool
online bool
watchdogConnected bool
}
var is Is = Is {
debug: false,
online: false,
watchdogConnected: false,
}
func main() { func main() {
postLog.Info(fmt.Sprintf("%s %s (Build %d.%s) by %s", softwareInfo.Name, softwareInfo.Version, softwareInfo.BuildVer, softwareInfo.BuildType, softwareInfo.Developer)) postLog.Info(fmt.Sprintf("%s %s (Build %d.%s) by %s", global.Software.Name, global.Software.Version, global.Software.BuildVer, global.Software.BuildType, global.Software.Developer))
configPath := flag.String("config", "./config.json", "path to config file") configPath := flag.String("config", "./config.json", "path to config file")
dbPath_data := flag.String("db", "./database.db", "path to database file") dbPath_data := flag.String("db", "./database.db", "path to database file")
dbPath_log := flag.String("log", "./logs.db", "path to logs database file") dbPath_log := flag.String("log", "./logs.db", "path to logs database file")
@@ -66,32 +38,32 @@ func main() {
postLog.Info(fmt.Sprintf("Created default config file at %s", *configPath)) postLog.Info(fmt.Sprintf("Created default config file at %s", *configPath))
} }
config, err := LoadConfig(*configPath) cfg, err := config.LoadConfig(*configPath, service.GetInitSystem)
if err != nil { if err != nil {
postLog.Fatal(fmt.Sprintf("Failed to load config: %v", err)) postLog.Fatal(fmt.Sprintf("Failed to load config: %v", err))
} }
postLog.SetDebugMode(config.Debug) postLog.SetDebugMode(cfg.Debug)
is.debug = config.Debug global.Is.Debug = cfg.Debug
if err := postLog.InitLogsDatabase(*dbPath_log); err != nil { if err := postLog.InitLogsDatabase(*dbPath_log); err != nil {
postLog.Fatal(fmt.Sprintf("Failed to initialize logs database: %v", err)) postLog.Fatal(fmt.Sprintf("Failed to initialize logs database: %v", err))
} }
postLog.Info("Logs database initialized successfully") postLog.Info("Logs database initialized successfully")
if err := InitDatabase(*dbPath_data, *dbPath_log); err != nil { if err := database.InitDatabase(*dbPath_data, *dbPath_log); err != nil {
postLog.Fatal(fmt.Sprintf("Failed to initialize database: %v", err)) postLog.Fatal(fmt.Sprintf("Failed to initialize database: %v", err))
} }
postLog.Info("Database initialized successfully") postLog.Info("Database initialized successfully")
if err := InitFrpcDatabase(*dbPath_data); err != nil { if err := database.InitFrpcDatabase(*dbPath_data); err != nil {
postLog.Fatal(fmt.Sprintf("Failed to initialize frpc database: %v", err)) postLog.Fatal(fmt.Sprintf("Failed to initialize frpc database: %v", err))
} }
frpLogger.SetDatabase(db, frpcDB) frpLogger.SetDatabase(global.ConfigDB, global.FrpcDB)
frpLogger.SetDebugMode(config.Debug) frpLogger.SetDebugMode(cfg.Debug)
_, err = GetConfig() _, err = config.GetConfig()
if err != nil { if err != nil {
postLog.Fatal(fmt.Sprintf("Failed to get config: %v", err)) postLog.Fatal(fmt.Sprintf("Failed to get config: %v", err))
} }
@@ -104,15 +76,15 @@ func main() {
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("Unable to initialize Watchdog: %s", err)) postLog.Error(fmt.Sprintf("Unable to initialize Watchdog: %s", err))
} else { } else {
if !watchdog.Connect("127.0.0.1", config.Watchdog.Port) { if !watchdog.Connect("127.0.0.1", cfg.Watchdog.Port) {
postLog.Error(fmt.Sprintf("Failed to connect to Watchdog at %s:%d", "127.0.0.1", config.Watchdog.Port)) postLog.Error(fmt.Sprintf("Failed to connect to Watchdog at %s:%d", "127.0.0.1", cfg.Watchdog.Port))
} else { } else {
postLog.Info(fmt.Sprintf("Connected to Watchdog at %s:%d", "127.0.0.1", config.Watchdog.Port)) postLog.Info(fmt.Sprintf("Connected to Watchdog at %s:%d", "127.0.0.1", cfg.Watchdog.Port))
is.watchdogConnected = true global.Is.WatchdogConnected = true
} }
} }
addr := fmt.Sprintf("%s:%s", config.ListenAddr, config.ListenPort) addr := fmt.Sprintf("%s:%s", cfg.ListenAddr, cfg.ListenPort)
server := &http.Server{ server := &http.Server{
Addr: addr, Addr: addr,
ReadTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second,
@@ -122,7 +94,7 @@ func main() {
go func() { go func() {
postLog.Info(fmt.Sprintf("Server starting on %s", addr)) postLog.Info(fmt.Sprintf("Server starting on %s", addr))
is.online = true global.Is.Online = true
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
postLog.Fatal(fmt.Sprintf("Failed to start server: %v", err)) postLog.Fatal(fmt.Sprintf("Failed to start server: %v", err))
} }
@@ -132,8 +104,8 @@ func main() {
ticker := time.NewTicker(1 * time.Hour) ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
CleanupExpiredTokens() session.CleanupExpiredTokens()
CleanupExpiredSessions() session.CleanupExpiredSessions()
} }
}() }()
@@ -147,37 +119,29 @@ func main() {
postLog.Error(fmt.Sprintf("Server closed with error: %v", err)) postLog.Error(fmt.Sprintf("Server closed with error: %v", err))
} }
if err := CloseDatabase(); err != nil { if err := database.CloseDatabase(); err != nil {
postLog.Error(fmt.Sprintf("Error closing database: %v", err)) postLog.Error(fmt.Sprintf("Error closing database: %v", err))
} }
if err := CloseFrpcDatabase(); err != nil {
postLog.Error(fmt.Sprintf("Error closing frpc database: %v", err))
}
watchdog.Close() watchdog.Close()
postLog.Info("Server stopped") postLog.Info("Server stopped")
} }
func GetStatusHandler(w http.ResponseWriter, r *http.Request) { func GetStatusHandler(w http.ResponseWriter, r *http.Request) {
statusInfo := StatusInfo{ statusInfo := map[string]string{
ServerStatus: "Offline", "ServerStatus": "Offline",
WatchdogStatus: "Offline", "WatchdogStatus": "Offline",
} }
if !is.online { if global.Is.Online {
statusInfo.ServerStatus = "Offline" statusInfo["ServerStatus"] = "Online"
} else {
statusInfo.ServerStatus = "Online"
} }
if !is.watchdogConnected { if global.Is.WatchdogConnected {
statusInfo.WatchdogStatus = "Offline" statusInfo["WatchdogStatus"] = "Online"
} else {
statusInfo.WatchdogStatus = "Online"
} }
SendSuccessResponse(w, "getStatus", statusInfo) handlers.SendSuccessResponse(w, "getStatus", statusInfo)
} }
func GetSoftwareInfoHandler(w http.ResponseWriter, r *http.Request) { func GetSoftwareInfoHandler(w http.ResponseWriter, r *http.Request) {
SendSuccessResponse(w, "getSoftwareInfo", softwareInfo) handlers.SendSuccessResponse(w, "getSoftwareInfo", global.Software)
} }

View File

@@ -4,7 +4,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"super-frpc/frpLogger" "super-frpc/frpLogger"
"super-frpc/handlers"
"super-frpc/instance"
"super-frpc/postLog" "super-frpc/postLog"
"super-frpc/session"
"super-frpc/user"
) )
func setupRoutes() { func setupRoutes() {
@@ -13,33 +17,33 @@ func setupRoutes() {
systemLogHandler := postLog.NewLogSocketHandler(postLog.GetLogBroadcaster()) systemLogHandler := postLog.NewLogSocketHandler(postLog.GetLogBroadcaster())
http.HandleFunc("/system/getLogs", systemLogHandler.Handle) http.HandleFunc("/system/getLogs", systemLogHandler.Handle)
http.HandleFunc("/register", RegisterHandler) http.HandleFunc("/register", user.RegisterHandler)
http.HandleFunc("/login", LoginHandler) http.HandleFunc("/login", user.LoginHandler)
http.HandleFunc("/logout", LogoutHandler) http.HandleFunc("/logout", user.LogoutHandler)
http.HandleFunc("/userMgr/create", CreateUserHandler) http.HandleFunc("/userMgr/create", user.CreateUserHandler)
http.HandleFunc("/userMgr/remove", RemoveUserHandler) http.HandleFunc("/userMgr/remove", user.RemoveUserHandler)
http.HandleFunc("/userMgr/list", ListUserHandler) http.HandleFunc("/userMgr/list", user.ListUserHandler)
http.HandleFunc("/userMgr/modify", ModifyUserHandler) http.HandleFunc("/userMgr/modify", user.ModifyUserHandler)
http.HandleFunc("/userMgr/modifyType", ModifyUserTypeHandler) http.HandleFunc("/userMgr/modifyType", user.ModifyUserTypeHandler)
http.HandleFunc("/sessionMgr/list", ListActiveSessionsHandler) http.HandleFunc("/sessionMgr/list", user.ListActiveSessionsHandler)
http.HandleFunc("/sessionMgr/remove", RemoveSessionHandler) http.HandleFunc("/sessionMgr/remove", user.RemoveSessionHandler)
http.HandleFunc("/frpcAct/instanceMgr/create", CreateInstanceHandler) http.HandleFunc("/frpcAct/instanceMgr/create", instance.CreateInstanceHandler)
http.HandleFunc("/frpcAct/instanceMgr/delete", DeleteInstanceHandler) http.HandleFunc("/frpcAct/instanceMgr/delete", instance.DeleteInstanceHandler)
http.HandleFunc("/frpcAct/instanceMgr/modify", ModifyInstanceHandler) http.HandleFunc("/frpcAct/instanceMgr/modify", instance.ModifyInstanceHandler)
http.HandleFunc("/frpcAct/instanceMgr/list", ListInstancesHandler) http.HandleFunc("/frpcAct/instanceMgr/list", instance.ListInstancesHandler)
http.HandleFunc("/frpcAct/instanceMgr/start", StartInstanceHandler) http.HandleFunc("/frpcAct/instanceMgr/start", instance.StartInstanceHandler)
http.HandleFunc("/frpcAct/instanceMgr/stop", StopInstanceHandler) http.HandleFunc("/frpcAct/instanceMgr/stop", instance.StopInstanceHandler)
http.HandleFunc("/frpcAct/instanceMgr/restart", RestartInstanceHandler) http.HandleFunc("/frpcAct/instanceMgr/restart", instance.RestartInstanceHandler)
http.HandleFunc("/frpcAct/instanceMgr/status", GetInstanceStatusHandler) http.HandleFunc("/frpcAct/instanceMgr/status", instance.GetInstanceStatusHandler)
http.HandleFunc("/frpcAct/instanceMgr/getInfo", GetInstanceInfoHandler) http.HandleFunc("/frpcAct/instanceMgr/getInfo", instance.GetInstanceInfoHandler)
http.HandleFunc("/frpcAct/instanceMgr/logs", frpLogger.NewInstanceLogHandler(ValidateTokenFromMap).ServeHTTP) http.HandleFunc("/frpcAct/instanceMgr/logs", frpLogger.NewInstanceLogHandler(session.ValidateTokenFromMap).ServeHTTP)
http.HandleFunc("/frpcAct/proxyMgr/create", CreateProxyHandler) http.HandleFunc("/frpcAct/proxyMgr/create", instance.CreateProxyHandler)
http.HandleFunc("/frpcAct/proxyMgr/modify", ModifyProxyHandler) http.HandleFunc("/frpcAct/proxyMgr/modify", instance.ModifyProxyHandler)
http.HandleFunc("/frpcAct/proxyMgr/delete", DeleteProxyHandler) http.HandleFunc("/frpcAct/proxyMgr/delete", instance.DeleteProxyHandler)
http.HandleFunc("/frpcAct/proxyMgr/list", ListProxiesHandler) http.HandleFunc("/frpcAct/proxyMgr/list", instance.ListProxiesHandler)
http.HandleFunc("/", NotFoundHandler) http.HandleFunc("/", NotFoundHandler)
@@ -49,5 +53,5 @@ func setupRoutes() {
func NotFoundHandler(w http.ResponseWriter, r *http.Request) { func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
postLog.Error(fmt.Sprintf("Route not found: %s %s", r.Method, r.URL.Path)) postLog.Error(fmt.Sprintf("Route not found: %s %s", r.Method, r.URL.Path))
SendErrorResponse(w, http.StatusNotFound, "Invalid request path") handlers.SendErrorResponse(w, http.StatusNotFound, "Invalid request path")
} }

View File

@@ -1,4 +1,4 @@
package main package service
import ( import (
"fmt" "fmt"
@@ -8,25 +8,28 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"super-frpc/config"
"super-frpc/database"
"super-frpc/postLog" "super-frpc/postLog"
) )
func GetConfigDir() (string, error) { func GetConfigDir() (string, error) {
config, err := GetConfig() cfg, err := config.GetConfig()
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[GetConfigDir] Failed to get config: %v", err)) postLog.Error(fmt.Sprintf("[GetConfigDir] Failed to get config: %v", err))
return "", err return "", err
} }
return config.InstancePath, nil return cfg.InstancePath, nil
} }
func GetFrpcPath() (string, error) { func GetFrpcPath() (string, error) {
config, err := GetConfig() cfg, err := config.GetConfig()
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[GetFrpcPath] Failed to get config: %v", err)) postLog.Error(fmt.Sprintf("[GetFrpcPath] Failed to get config: %v", err))
return "", err return "", err
} }
return config.FrpcPath, nil return cfg.FrpcPath, nil
} }
func GetInitSystem() string { func GetInitSystem() string {
@@ -44,8 +47,8 @@ func GetInitSystem() string {
return "unknown" return "unknown"
} }
func createBootService(instanceID int) error { func CreateBootService(instanceID int) error {
instance, err := DBQueryFrpcInstanceByID(instanceID) instance, err := database.DBQueryFrpcInstanceByID(instanceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to query frpc instance: %w", err) return fmt.Errorf("failed to query frpc instance: %w", err)
} }
@@ -70,17 +73,17 @@ func createSystemdService(instanceID int, configPath, runUser string) error {
return err return err
} }
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
return err return err
} }
instance, err := DBQueryFrpcInstanceByID(instanceID) instance, err := database.DBQueryFrpcInstanceByID(instanceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to query frpc instance: %w", err) return fmt.Errorf("failed to query frpc instance: %w", err)
} }
user, err := GetUserByID(instance.UserID) user, err := database.GetUserByID(instance.UserID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user info: %w", err) return fmt.Errorf("failed to get user info: %w", err)
} }
@@ -125,17 +128,17 @@ func createInitDService(instanceID int, configPath, runUser string) error {
return err return err
} }
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
return err return err
} }
instance, err := DBQueryFrpcInstanceByID(instanceID) instance, err := database.DBQueryFrpcInstanceByID(instanceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to query frpc instance: %w", err) return fmt.Errorf("failed to query frpc instance: %w", err)
} }
user, err := GetUserByID(instance.UserID) user, err := database.GetUserByID(instance.UserID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user info: %w", err) return fmt.Errorf("failed to get user info: %w", err)
} }
@@ -268,7 +271,7 @@ func createWindowsBootService(instanceID int, configPath string) error {
return err return err
} }
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
return err return err
} }
@@ -285,9 +288,9 @@ func createWindowsBootService(instanceID int, configPath string) error {
return nil return nil
} }
func setBootAtStart(instanceID int) error { func SetBootAtStart(instanceID int) error {
initType := GetInitSystem() initType := GetInitSystem()
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
return err return err
} }
@@ -319,9 +322,9 @@ func setBootAtStart(instanceID int) error {
} }
} }
func removeBootAtStart(instanceID int) error { func RemoveBootAtStart(instanceID int) error {
initType := GetInitSystem() initType := GetInitSystem()
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
return err return err
} }
@@ -353,9 +356,9 @@ func removeBootAtStart(instanceID int) error {
} }
} }
func removeBootService(instanceID int) error { func RemoveBootService(instanceID int) error {
initType := GetInitSystem() initType := GetInitSystem()
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
return err return err
} }
@@ -393,13 +396,13 @@ func removeBootService(instanceID int) error {
func IsInstanceRunning(instanceID int) error { func IsInstanceRunning(instanceID int) error {
initType := GetInitSystem() initType := GetInitSystem()
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[IsInstanceRunning] Failed to get service name: %v", err)) postLog.Error(fmt.Sprintf("[IsInstanceRunning] Failed to get service name: %v", err))
return err return err
} }
instance, err := DBQueryFrpcInstanceByID(instanceID) instance, err := database.DBQueryFrpcInstanceByID(instanceID)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[IsInstanceRunning] Failed to query instance: %v", err)) postLog.Error(fmt.Sprintf("[IsInstanceRunning] Failed to query instance: %v", err))
return err return err
@@ -504,7 +507,7 @@ func IsInstanceRunning(instanceID int) error {
func GetInstancePid(instanceID int) int { func GetInstancePid(instanceID int) int {
initType := GetInitSystem() initType := GetInitSystem()
serviceName, err := GetServiceNameByInstanceID(instanceID) serviceName, err := database.GetServiceNameByInstanceID(instanceID)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[GetInstancePid] Failed to get service name: %v", err)) postLog.Error(fmt.Sprintf("[GetInstancePid] Failed to get service name: %v", err))
return 0 return 0

View File

@@ -1,4 +1,4 @@
package main package session
import ( import (
"crypto/rand" "crypto/rand"
@@ -9,9 +9,11 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"super-frpc/postLog"
"sync" "sync"
"time" "time"
"super-frpc/database"
"super-frpc/postLog"
) )
type TokenInfo struct { type TokenInfo struct {
@@ -20,12 +22,6 @@ type TokenInfo struct {
UserID int UserID int
} }
var (
tokenMap = make(map[int]*TokenInfo) // userID -> TokenInfo
tokenMux sync.RWMutex
tokenTTL = time.Hour
)
type Session struct { type Session struct {
ID string ID string
UserID int UserID int
@@ -34,13 +30,19 @@ type Session struct {
} }
var ( var (
sessionMap = make(map[string]*Session) // sessionID -> Session tokenMap = make(map[int]*TokenInfo)
tokenMux sync.RWMutex
tokenTTL = time.Hour
)
var (
sessionMap = make(map[string]*Session)
sessionMux sync.RWMutex sessionMux sync.RWMutex
sessionTTL = time.Hour sessionTTL = time.Hour
) )
var ( var (
sessionTokenMap = make(map[string]string) // sessionID -> token sessionTokenMap = make(map[string]string)
sessionTokenMux sync.RWMutex sessionTokenMux sync.RWMutex
) )
@@ -130,7 +132,7 @@ func GetTokenInfo(userID int) (*TokenInfo, error) {
return tokenInfo, nil return tokenInfo, nil
} }
func extractUserIDFromToken(token string) (int, error) { func ExtractUserIDFromToken(token string) (int, error) {
tokenMux.RLock() tokenMux.RLock()
defer tokenMux.RUnlock() defer tokenMux.RUnlock()
for userID, tokenInfo := range tokenMap { for userID, tokenInfo := range tokenMap {
@@ -178,13 +180,13 @@ func CleanupExpiredSessions() {
} }
} }
func hashPassword(password string) (string, error) { func HashPassword(password string) (string, error) {
hash := sha256.Sum256([]byte(password)) hash := sha256.Sum256([]byte(password))
return hex.EncodeToString(hash[:]), nil return hex.EncodeToString(hash[:]), nil
} }
func verifyPassword(password, hashedPassword string) bool { func VerifyPassword(password, hashedPassword string) bool {
hash, err := hashPassword(password) hash, err := HashPassword(password)
if err != nil { if err != nil {
postLog.Error(fmt.Sprintf("[verifyPassword] Failed to hash password: %v", err)) postLog.Error(fmt.Sprintf("[verifyPassword] Failed to hash password: %v", err))
return false return false
@@ -192,7 +194,7 @@ func verifyPassword(password, hashedPassword string) bool {
return hash == hashedPassword return hash == hashedPassword
} }
func isValidPassword(password string) bool { // Validate password complexity and generate hash func IsValidPassword(password string) bool {
if len(password) < 8 { if len(password) < 8 {
return false return false
} }
@@ -220,8 +222,8 @@ func isValidPassword(password string) bool { // Validate password complexity and
return hasUpper && hasLower && hasDigit && hasSpecial return hasUpper && hasLower && hasDigit && hasSpecial
} }
func ValidateTimeStamp(header http.Header) bool { func ValidateTimeStamp(header http.Header, debug bool) bool {
if is.debug { if debug {
return true return true
} }
@@ -243,7 +245,7 @@ func ValidateTimeStamp(header http.Header) bool {
} }
func GetUserIDFromToken(token string) (int, error) { func GetUserIDFromToken(token string) (int, error) {
userID, err := extractUserIDFromToken(token) userID, err := ExtractUserIDFromToken(token)
if err != nil { if err != nil {
return 0, fmt.Errorf("Failed to extract userID from token: %w", err) return 0, fmt.Errorf("Failed to extract userID from token: %w", err)
} }
@@ -259,7 +261,7 @@ func DeleteTokenInfo(userID int) error {
} }
func GetUsernameByID(userID int) string { func GetUsernameByID(userID int) string {
user, err := GetUserByID(userID) user, err := database.GetUserByID(userID)
if err != nil { if err != nil {
return "" return ""
} }
@@ -346,3 +348,24 @@ func ValidateTokenFromMap(token string) (int, error) {
} }
return 0, fmt.Errorf("invalid token") return 0, fmt.Errorf("invalid token")
} }
func GetSessionTokenMap() *sync.RWMutex {
return &sessionTokenMux
}
func GetSessionTokenMapSnapshot() map[string]string {
sessionTokenMux.RLock()
defer sessionTokenMux.RUnlock()
snapshot := make(map[string]string)
for k, v := range sessionTokenMap {
snapshot[k] = v
}
return snapshot
}
func GetSessionToken(sessionID string) (string, bool) {
sessionTokenMux.RLock()
defer sessionTokenMux.RUnlock()
token, exists := sessionTokenMap[sessionID]
return token, exists
}

View File

@@ -1,11 +1,16 @@
package main package user
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"super-frpc/database"
"super-frpc/global"
"super-frpc/handlers"
"super-frpc/postLog" "super-frpc/postLog"
"super-frpc/session"
) )
type RegisterRequest struct { type RegisterRequest struct {
@@ -45,20 +50,20 @@ type RemoveSessionRequest struct {
func RegisterHandler(w http.ResponseWriter, r *http.Request) { func RegisterHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
SendErrorResponse(w, http.StatusMethodNotAllowed, "Invalid request method") handlers.SendErrorResponse(w, http.StatusMethodNotAllowed, "Invalid request method")
postLog.Warning(fmt.Sprintf("[RegisterHandler] Invalid request method: %s", r.Method)) postLog.Warning(fmt.Sprintf("[RegisterHandler] Invalid request method: %s", r.Method))
return return
} }
if !ValidateTimeStamp(r.Header) { if !session.ValidateTimeStamp(r.Header, global.Is.Debug) {
SendErrorResponse(w, http.StatusBadRequest, "Invalid or missing X-Timestamp in header") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid or missing X-Timestamp in header")
postLog.Warning(fmt.Sprintf("[RegisterHandler] Invalid or missing X-Timestamp in header: %s", r.Header.Get("X-Timestamp"))) postLog.Warning(fmt.Sprintf("[RegisterHandler] Invalid or missing X-Timestamp in header: %s", r.Header.Get("X-Timestamp")))
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[RegisterHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[RegisterHandler] Failed to read request body: %v", err))
return return
} }
@@ -66,33 +71,32 @@ func RegisterHandler(w http.ResponseWriter, r *http.Request) {
var req RegisterRequest var req RegisterRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[RegisterHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[RegisterHandler] Invalid request format: %v", err))
return return
} }
if req.Username == "" || req.Passwd == "" { if req.Username == "" || req.Passwd == "" {
SendErrorResponse(w, http.StatusBadRequest, "Username and password are required") handlers.SendErrorResponse(w, http.StatusBadRequest, "Username and password are required")
postLog.Warning("[RegisterHandler] New user registration failed: username or password is empty") postLog.Warning("[RegisterHandler] New user registration failed: username or password is empty")
return return
} }
if !isValidInput(req.Username) || !isValidInput(req.Passwd) { if !isValidInput(req.Username) || !isValidInput(req.Passwd) {
SendErrorResponse(w, http.StatusBadRequest, "Invalid input: contains illegal characters") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid input: contains illegal characters")
postLog.Debug(fmt.Sprintf("[RegisterHandler] New user registration failed: username or password contains illegal characters \"%s\":\"%s\"", req.Username, req.Passwd)) postLog.Debug(fmt.Sprintf("[RegisterHandler] New user registration failed: username or password contains illegal characters \"%s\":\"%s\"", req.Username, req.Passwd))
return return
} }
if !isValidPassword(req.Passwd) { if !session.IsValidPassword(req.Passwd) {
SendErrorResponse(w, http.StatusBadRequest, "Password does not meet complexity requirements (must contain uppercase, lowercase, digit, and special character)") handlers.SendErrorResponse(w, http.StatusBadRequest, "Password does not meet complexity requirements (must contain uppercase, lowercase, digit, and special character)")
postLog.Debug(fmt.Sprintf("[RegisterHandler] New user registration failed: password \"%s\" does not meet complexity requirements", req.Passwd)) postLog.Debug(fmt.Sprintf("[RegisterHandler] New user registration failed: password \"%s\" does not meet complexity requirements", req.Passwd))
return return
} }
// Check weather the user is the first user, if it is, give the superuser level userList, err := database.DBListUsers()
userList, err := DBListUsers()
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, err.Error()) handlers.SendErrorResponse(w, http.StatusInternalServerError, err.Error())
postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to list users: %v", err)) postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to list users: %v", err))
return return
} }
@@ -103,21 +107,21 @@ func RegisterHandler(w http.ResponseWriter, r *http.Request) {
newUserType = "visitor" newUserType = "visitor"
} }
userID, err := AddUser(req.Username, req.Passwd, newUserType) userID, err := database.AddUser(req.Username, req.Passwd, newUserType, session.HashPassword, session.IsValidPassword)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, err.Error()) handlers.SendErrorResponse(w, http.StatusInternalServerError, err.Error())
postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to register user \"%s\": %v", req.Username, err)) postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to register user \"%s\": %v", req.Username, err))
return return
} }
user, err := GetUserByID(userID) user, err := database.GetUserByID(userID)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, "Failed to retrieve user after registration") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to retrieve user after registration")
postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to retrieve user \"%s\" after registration: %v", req.Username, err)) postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to retrieve user \"%s\" after registration: %v", req.Username, err))
return return
} }
SendSuccessResponse(w, "user registered successfully", map[string]interface{}{ handlers.SendSuccessResponse(w, "user registered successfully", map[string]interface{}{
"userID": user.UserID, "userID": user.UserID,
"username": user.Username, "username": user.Username,
"type": user.Type, "type": user.Type,
@@ -126,19 +130,19 @@ func RegisterHandler(w http.ResponseWriter, r *http.Request) {
func LoginHandler(w http.ResponseWriter, r *http.Request) { func LoginHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
SendErrorResponse(w, http.StatusMethodNotAllowed, "Invalid request method") handlers.SendErrorResponse(w, http.StatusMethodNotAllowed, "Invalid request method")
postLog.Warning(fmt.Sprintf("[LoginHandler] Invalid request method: %s", r.Method)) postLog.Warning(fmt.Sprintf("[LoginHandler] Invalid request method: %s", r.Method))
return return
} }
if !ValidateTimeStamp(r.Header) { if !session.ValidateTimeStamp(r.Header, global.Is.Debug) {
SendErrorResponse(w, http.StatusBadRequest, "Invalid or missing X-Timestamp in header") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid or missing X-Timestamp in header")
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[LoginHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[LoginHandler] Failed to read request body: %v", err))
return return
} }
@@ -146,57 +150,57 @@ func LoginHandler(w http.ResponseWriter, r *http.Request) {
var req LoginRequest var req LoginRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[LoginHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[LoginHandler] Invalid request format: %v", err))
return return
} }
if req.Username == "" || req.Passwd == "" { if req.Username == "" || req.Passwd == "" {
SendErrorResponse(w, http.StatusBadRequest, "Username and password are required") handlers.SendErrorResponse(w, http.StatusBadRequest, "Username and password are required")
postLog.Warning("[LoginHandler] Login failed: username or password is empty") postLog.Warning("[LoginHandler] Login failed: username or password is empty")
return return
} }
if !isValidInput(req.Username) || !isValidInput(req.Passwd) { if !isValidInput(req.Username) || !isValidInput(req.Passwd) {
SendErrorResponse(w, http.StatusBadRequest, "Invalid input: contains illegal characters") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid input: contains illegal characters")
postLog.Debug(fmt.Sprintf("[LoginHandler] Login failed: username or password contains illegal characters \"%s\":\"%s\"", req.Username, req.Passwd)) postLog.Debug(fmt.Sprintf("[LoginHandler] Login failed: username or password contains illegal characters \"%s\":\"%s\"", req.Username, req.Passwd))
return return
} }
user, err := GetUserByUsername(req.Username) user, err := database.GetUserByUsername(req.Username)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, "User not exist") handlers.SendErrorResponse(w, http.StatusUnauthorized, "User not exist")
postLog.Warning(fmt.Sprintf("[LoginHandler] Login failed: User not exist \"%s\"", req.Username)) postLog.Warning(fmt.Sprintf("[LoginHandler] Login failed: User not exist \"%s\"", req.Username))
return return
} }
if !verifyPassword(req.Passwd, user.Passwd) { if !session.VerifyPassword(req.Passwd, user.Passwd) {
SendErrorResponse(w, http.StatusUnauthorized, "Invalid password") handlers.SendErrorResponse(w, http.StatusUnauthorized, "Invalid password")
postLog.Warning(fmt.Sprintf("[LoginHandler] Login failed: invalid password for user \"%s\"", req.Username)) postLog.Warning(fmt.Sprintf("[LoginHandler] Login failed: invalid password for user \"%s\"", req.Username))
return return
} }
existingTokenInfo, err := GetTokenInfo(user.UserID) existingTokenInfo, err := session.GetTokenInfo(user.UserID)
if err == nil && existingTokenInfo != nil { if err == nil && existingTokenInfo != nil {
SendErrorResponse(w, http.StatusConflict, "User is already logged in") handlers.SendErrorResponse(w, http.StatusConflict, "User is already logged in")
postLog.Warning(fmt.Sprintf("[LoginHandler] Login failed: user \"%s\" is already logged in", req.Username)) postLog.Warning(fmt.Sprintf("[LoginHandler] Login failed: user \"%s\" is already logged in", req.Username))
return return
} }
token, err := GenerateToken(user.UserID) token, err := session.GenerateToken(user.UserID)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, "Failed to generate token") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to generate token")
postLog.Error(fmt.Sprintf("[LoginHandler] Failed to generate token for user \"%s\": %v", req.Username, err)) postLog.Error(fmt.Sprintf("[LoginHandler] Failed to generate token for user \"%s\": %v", req.Username, err))
return return
} }
if err := JoinSession(user.UserID, user.Username, token); err != nil { if err := session.JoinSession(user.UserID, user.Username, token); err != nil {
SendErrorResponse(w, http.StatusInternalServerError, "Failed to create session") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to create session")
postLog.Error(fmt.Sprintf("[LoginHandler] Failed to create session for user \"%s\": %v", req.Username, err)) postLog.Error(fmt.Sprintf("[LoginHandler] Failed to create session for user \"%s\": %v", req.Username, err))
return return
} }
SendSuccessResponse(w, "Login successful", map[string]interface{}{ handlers.SendSuccessResponse(w, "Login successful", map[string]interface{}{
"token": token, "token": token,
"userID": user.UserID, "userID": user.UserID,
"username": user.Username, "username": user.Username,
@@ -206,16 +210,17 @@ func LoginHandler(w http.ResponseWriter, r *http.Request) {
} }
func LogoutHandler(w http.ResponseWriter, r *http.Request) { func LogoutHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodGet) userID, err := handlers.Auth(w, r, http.MethodGet)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[LogoutHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[LogoutHandler] Auth failed: %v", err))
return return
} }
sessionTokenMux := session.GetSessionTokenMap()
sessionTokenMux.RLock() sessionTokenMux.RLock()
sessionID := "" sessionID := ""
for sid, token := range sessionTokenMap { for sid, token := range session.GetSessionTokenMapSnapshot() {
if token == r.Header.Get("X-Token") { if token == r.Header.Get("X-Token") {
sessionID = sid sessionID = sid
break break
@@ -224,32 +229,32 @@ func LogoutHandler(w http.ResponseWriter, r *http.Request) {
sessionTokenMux.RUnlock() sessionTokenMux.RUnlock()
if sessionID == "" { if sessionID == "" {
SendErrorResponse(w, http.StatusNotFound, "Session not found for token") handlers.SendErrorResponse(w, http.StatusNotFound, "Session not found for token")
postLog.Warning(fmt.Sprintf("[LogoutHandler] Session not found for token from user [%d]%s", userID, GetUsernameByID(userID))) postLog.Warning(fmt.Sprintf("[LogoutHandler] Session not found for token from user [%d]%s", userID, session.GetUsernameByID(userID)))
return return
} }
if err := RemoveSession(sessionID); err != nil { if err := session.RemoveSession(sessionID); err != nil {
SendErrorResponse(w, http.StatusInternalServerError, "Failed to logout") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to logout")
postLog.Error(fmt.Sprintf("[LogoutHandler] Failed to logout user [%d]%s: %v", userID, GetUsernameByID(userID), err)) postLog.Error(fmt.Sprintf("[LogoutHandler] Failed to logout user [%d]%s: %v", userID, session.GetUsernameByID(userID), err))
return return
} }
SendSuccessResponse(w, "Logout successful", nil) handlers.SendSuccessResponse(w, "Logout successful", nil)
postLog.Info(fmt.Sprintf("[LogoutHandler] User [%d]%s Logout successful", userID, GetUsernameByID(userID))) postLog.Info(fmt.Sprintf("[LogoutHandler] User [%d]%s Logout successful", userID, session.GetUsernameByID(userID)))
} }
func RemoveSessionHandler(w http.ResponseWriter, r *http.Request) { func RemoveSessionHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodPost, "superuser") userID, err := handlers.Auth(w, r, http.MethodPost, "superuser")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[RemoveSessionHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[RemoveSessionHandler] Auth failed: %v", err))
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[RemoveSessionHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[RemoveSessionHandler] Failed to read request body: %v", err))
return return
} }
@@ -257,38 +262,38 @@ func RemoveSessionHandler(w http.ResponseWriter, r *http.Request) {
var req RemoveSessionRequest var req RemoveSessionRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[RemoveSessionHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[RemoveSessionHandler] Invalid request format: %v", err))
return return
} }
if req.SessionID == "" { if req.SessionID == "" {
SendErrorResponse(w, http.StatusBadRequest, "SessionID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "SessionID is required")
postLog.Warning("[RemoveSessionHandler] SessionID is empty") postLog.Warning("[RemoveSessionHandler] SessionID is empty")
return return
} }
if err := RemoveSession(req.SessionID); err != nil { if err := session.RemoveSession(req.SessionID); err != nil {
SendErrorResponse(w, http.StatusInternalServerError, fmt.Sprintf("Failed to remove session: %v", err)) handlers.SendErrorResponse(w, http.StatusInternalServerError, fmt.Sprintf("Failed to remove session: %v", err))
postLog.Error(fmt.Sprintf("[RemoveSessionHandler] Failed to remove session %s: %v", req.SessionID, err)) postLog.Error(fmt.Sprintf("[RemoveSessionHandler] Failed to remove session %s: %v", req.SessionID, err))
return return
} }
postLog.Info(fmt.Sprintf("[RemoveSessionHandler] User [%d]%s removed session %s", userID, GetUsernameByID(userID), req.SessionID)) postLog.Info(fmt.Sprintf("[RemoveSessionHandler] User [%d]%s removed session %s", userID, session.GetUsernameByID(userID), req.SessionID))
SendSuccessResponse(w, "Session removed successfully", nil) handlers.SendSuccessResponse(w, "Session removed successfully", nil)
} }
func CreateUserHandler(w http.ResponseWriter, r *http.Request) { func CreateUserHandler(w http.ResponseWriter, r *http.Request) {
_, err := Auth(w, r, http.MethodPost, "superuser") _, err := handlers.Auth(w, r, http.MethodPost, "superuser")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[CreateUserHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[CreateUserHandler] Auth failed: %v", err))
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[CreateUserHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[CreateUserHandler] Failed to read request body: %v", err))
return return
} }
@@ -296,38 +301,38 @@ func CreateUserHandler(w http.ResponseWriter, r *http.Request) {
var req CreateUserRequest var req CreateUserRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[CreateUserHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[CreateUserHandler] Invalid request format: %v", err))
return return
} }
if req.Username == "" || req.Passwd == "" || req.Type == "" { if req.Username == "" || req.Passwd == "" || req.Type == "" {
SendErrorResponse(w, http.StatusBadRequest, "Username, password, and type are required") handlers.SendErrorResponse(w, http.StatusBadRequest, "Username, password, and type are required")
postLog.Warning("[CreateUserHandler] CreateUser failed: username, password, or type is empty") postLog.Warning("[CreateUserHandler] CreateUser failed: username, password, or type is empty")
return return
} }
if req.Type != "admin" && req.Type != "user" && req.Type != "superuser" { if req.Type != "admin" && req.Type != "user" && req.Type != "superuser" {
SendErrorResponse(w, http.StatusBadRequest, "Invalid type: must be 'admin' or 'user' or 'superuser'") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid type: must be 'admin' or 'user' or 'superuser'")
postLog.Warning(fmt.Sprintf("[CreateUserHandler] CreateUser failed: invalid type: %s", req.Type)) postLog.Warning(fmt.Sprintf("[CreateUserHandler] CreateUser failed: invalid type: %s", req.Type))
return return
} }
userID, err := AddUser(req.Username, req.Passwd, req.Type) userID, err := database.AddUser(req.Username, req.Passwd, req.Type, session.HashPassword, session.IsValidPassword)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, err.Error()) handlers.SendErrorResponse(w, http.StatusInternalServerError, err.Error())
postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to register user \"%s\": %v", req.Username, err)) postLog.Error(fmt.Sprintf("[RegisterHandler] Failed to register user \"%s\": %v", req.Username, err))
return return
} }
user, err := GetUserByID(userID) user, err := database.GetUserByID(userID)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, "Failed to retrieve user after creation") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to retrieve user after creation")
postLog.Error(fmt.Sprintf("[CreateUserHandler] Failed to retrieve user \"%s\" after creation: %v", req.Username, err)) postLog.Error(fmt.Sprintf("[CreateUserHandler] Failed to retrieve user \"%s\" after creation: %v", req.Username, err))
return return
} }
SendSuccessResponse(w, "User created successfully", map[string]interface{}{ handlers.SendSuccessResponse(w, "User created successfully", map[string]interface{}{
"userID": user.UserID, "userID": user.UserID,
"username": user.Username, "username": user.Username,
"type": user.Type, "type": user.Type,
@@ -335,16 +340,16 @@ func CreateUserHandler(w http.ResponseWriter, r *http.Request) {
postLog.Info(fmt.Sprintf("[CreateUserHandler] User \"%s\" created successfully with ID: %d", req.Username, userID)) postLog.Info(fmt.Sprintf("[CreateUserHandler] User \"%s\" created successfully with ID: %d", req.Username, userID))
} }
func ModifyUserHandler (w http.ResponseWriter, r *http.Request) { func ModifyUserHandler(w http.ResponseWriter, r *http.Request) {
_, err := Auth(w, r, http.MethodPost, "user", "admin", "superuser") _, err := handlers.Auth(w, r, http.MethodPost, "user", "admin", "superuser")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, fmt.Sprintf("Auth failed: %v", err)) handlers.SendErrorResponse(w, http.StatusUnauthorized, fmt.Sprintf("Auth failed: %v", err))
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[ModifyUserHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[ModifyUserHandler] Failed to read request body: %v", err))
return return
} }
@@ -353,43 +358,43 @@ func ModifyUserHandler (w http.ResponseWriter, r *http.Request) {
var req ModifyUserRequest var req ModifyUserRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[ModifyUserHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[ModifyUserHandler] Invalid request format: %v", err))
return return
} }
if req.UserID == 0 { if req.UserID == 0 {
SendErrorResponse(w, http.StatusBadRequest, "UserID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "UserID is required")
postLog.Warning("[ModifyUserHandler] ModifyUser failed: UserID is empty") postLog.Warning("[ModifyUserHandler] ModifyUser failed: UserID is empty")
return return
} }
if req.Username == "" { if req.Username == "" {
SendErrorResponse(w, http.StatusBadRequest, "Username is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "Username is required")
postLog.Warning("[ModifyUserHandler] ModifyUser failed: username is empty") postLog.Warning("[ModifyUserHandler] ModifyUser failed: username is empty")
return return
} }
if err := DBUpdateUser(req.UserID, req.Username, req.Passwd); err != nil { if err := database.DBUpdateUser(req.UserID, req.Username, req.Passwd); err != nil {
SendErrorResponse(w, http.StatusInternalServerError, err.Error()) handlers.SendErrorResponse(w, http.StatusInternalServerError, err.Error())
postLog.Error(fmt.Sprintf("[ModifyUserHandler] Failed to update user [%d]: %v", req.UserID, err)) postLog.Error(fmt.Sprintf("[ModifyUserHandler] Failed to update user [%d]: %v", req.UserID, err))
return return
} }
SendSuccessResponse(w, "User updated successfully", nil) handlers.SendSuccessResponse(w, "User updated successfully", nil)
postLog.Info(fmt.Sprintf("[ModifyUserHandler] User [%d]%s updated successfully to: %s", req.UserID, GetUsernameByID(req.UserID), req.Username)) postLog.Info(fmt.Sprintf("[ModifyUserHandler] User [%d]%s updated successfully to: %s", req.UserID, session.GetUsernameByID(req.UserID), req.Username))
} }
func ModifyUserTypeHandler (w http.ResponseWriter, r *http.Request) { func ModifyUserTypeHandler(w http.ResponseWriter, r *http.Request) {
_, err := Auth(w, r, http.MethodPost, "superuser") _, err := handlers.Auth(w, r, http.MethodPost, "superuser")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] Auth failed: %v", err))
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] Failed to read request body: %v", err))
return return
} }
@@ -398,40 +403,40 @@ func ModifyUserTypeHandler (w http.ResponseWriter, r *http.Request) {
var req ModifyUserTypeRequest var req ModifyUserTypeRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] Invalid request format: %v", err))
return return
} }
if req.UserID == 0 { if req.UserID == 0 {
SendErrorResponse(w, http.StatusBadRequest, "UserID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "UserID is required")
postLog.Warning("[ModifyUserTypeHandler] ModifyUserType failed: UserID is empty") postLog.Warning("[ModifyUserTypeHandler] ModifyUserType failed: UserID is empty")
return return
} }
if req.Type != "admin" && req.Type != "visitor" && req.Type != "superuser" { if req.Type != "admin" && req.Type != "visitor" && req.Type != "superuser" {
SendErrorResponse(w, http.StatusBadRequest, "Invalid type: must be 'admin' or 'visitor' or 'superuser'") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid type: must be 'admin' or 'visitor' or 'superuser'")
postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] ModifyUserType failed: invalid type: %s", req.Type)) postLog.Warning(fmt.Sprintf("[ModifyUserTypeHandler] ModifyUserType failed: invalid type: %s", req.Type))
return return
} }
if err := DBUpdateUserType(req.UserID, req.Type); err != nil { if err := database.DBUpdateUserType(req.UserID, req.Type); err != nil {
SendErrorResponse(w, http.StatusInternalServerError, err.Error()) handlers.SendErrorResponse(w, http.StatusInternalServerError, err.Error())
postLog.Error(fmt.Sprintf("[ModifyUserTypeHandler] Failed to update user type [%d]: %v", req.UserID, err)) postLog.Error(fmt.Sprintf("[ModifyUserTypeHandler] Failed to update user type [%d]: %v", req.UserID, err))
return return
} }
SendSuccessResponse(w, "User type updated successfully", nil) handlers.SendSuccessResponse(w, "User type updated successfully", nil)
postLog.Info(fmt.Sprintf("[ModifyUserTypeHandler] User [%d]%s type updated successfully to: %s", req.UserID, GetUsernameByID(req.UserID), req.Type)) postLog.Info(fmt.Sprintf("[ModifyUserTypeHandler] User [%d]%s type updated successfully to: %s", req.UserID, session.GetUsernameByID(req.UserID), req.Type))
} }
func RemoveUserHandler(w http.ResponseWriter, r *http.Request) { func RemoveUserHandler(w http.ResponseWriter, r *http.Request) {
_, err := Auth(w, r, http.MethodPost, "superuser") _, err := handlers.Auth(w, r, http.MethodPost, "superuser")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[RemoveUserHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[RemoveUserHandler] Auth failed: %v", err))
return return
} }
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body") handlers.SendErrorResponse(w, http.StatusBadRequest, "Failed to read request body")
postLog.Warning(fmt.Sprintf("[RemoveUserHandler] Failed to read request body: %v", err)) postLog.Warning(fmt.Sprintf("[RemoveUserHandler] Failed to read request body: %v", err))
return return
} }
@@ -439,49 +444,53 @@ func RemoveUserHandler(w http.ResponseWriter, r *http.Request) {
var req RemoveUserRequest var req RemoveUserRequest
if err := json.Unmarshal(body, &req); err != nil { if err := json.Unmarshal(body, &req); err != nil {
SendErrorResponse(w, http.StatusBadRequest, "Invalid request format") handlers.SendErrorResponse(w, http.StatusBadRequest, "Invalid request format")
postLog.Warning(fmt.Sprintf("[RemoveUserHandler] Invalid request format: %v", err)) postLog.Warning(fmt.Sprintf("[RemoveUserHandler] Invalid request format: %v", err))
return return
} }
if req.TargetUserID == 0 { if req.TargetUserID == 0 {
SendErrorResponse(w, http.StatusBadRequest, "TargetUserID is required") handlers.SendErrorResponse(w, http.StatusBadRequest, "TargetUserID is required")
postLog.Warning("[RemoveUserHandler] RemoveUser failed: TargetUserID is empty") postLog.Warning("[RemoveUserHandler] RemoveUser failed: TargetUserID is empty")
return return
} }
if err := RemoveUser(req.TargetUserID); err != nil { if err := database.RemoveUser(req.TargetUserID); err != nil {
SendErrorResponse(w, http.StatusInternalServerError, err.Error()) handlers.SendErrorResponse(w, http.StatusInternalServerError, err.Error())
postLog.Error(fmt.Sprintf("[RemoveUserHandler] Failed to remove user [%d]: %v", req.TargetUserID, err)) postLog.Error(fmt.Sprintf("[RemoveUserHandler] Failed to remove user [%d]: %v", req.TargetUserID, err))
return return
} }
SendSuccessResponse(w, "User removed successfully", nil) handlers.SendSuccessResponse(w, "User removed successfully", nil)
} }
func ListUserHandler(w http.ResponseWriter, r *http.Request) { func ListUserHandler(w http.ResponseWriter, r *http.Request) {
_, err := Auth(w, r, http.MethodGet, "superuser") _, err := handlers.Auth(w, r, http.MethodGet, "superuser")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[ListUserHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[ListUserHandler] Auth failed: %v", err))
return return
} }
userList, err := DBListUsers() userList, err := database.DBListUsers()
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusInternalServerError, "Failed to list users") handlers.SendErrorResponse(w, http.StatusInternalServerError, "Failed to list users")
postLog.Error(fmt.Sprintf("[ListUserHandler] Failed to list users: %v", err)) postLog.Error(fmt.Sprintf("[ListUserHandler] Failed to list users: %v", err))
return return
} }
SendSuccessResponse(w, "User list retrieved successfully", userList) handlers.SendSuccessResponse(w, "User list retrieved successfully", userList)
} }
func ListActiveSessionsHandler(w http.ResponseWriter, r *http.Request) { func ListActiveSessionsHandler(w http.ResponseWriter, r *http.Request) {
userID, err := Auth(w, r, http.MethodGet, "superuser", "admin") userID, err := handlers.Auth(w, r, http.MethodGet, "superuser", "admin")
if err != nil { if err != nil {
SendErrorResponse(w, http.StatusUnauthorized, err.Error()) handlers.SendErrorResponse(w, http.StatusUnauthorized, err.Error())
postLog.Warning(fmt.Sprintf("[ListActiveSessionsHandler] Auth failed: %v", err)) postLog.Warning(fmt.Sprintf("[ListActiveSessionsHandler] Auth failed: %v", err))
return return
} }
sessions := ListActiveSessions() sessions := session.ListActiveSessions()
postLog.Debug(fmt.Sprintf("[ListActiveSessionsHandler] User [%d]%s listed %d active sessions", userID, GetUsernameByID(userID), len(sessions))) postLog.Debug(fmt.Sprintf("[ListActiveSessionsHandler] User [%d]%s listed %d active sessions", userID, session.GetUsernameByID(userID), len(sessions)))
SendSuccessResponse(w, "Active sessions listed", sessions) handlers.SendSuccessResponse(w, "Active sessions listed", sessions)
}
func isValidInput(input string) bool {
return database.IsValidInput(input)
} }