- Add `getSystemdServiceName` function to get the systemd service name. - When running as a systemd service, use `systemctl restart` to restart the program, then exit.
237 lines
5.9 KiB
Go
237 lines
5.9 KiB
Go
package utils
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"path/filepath"
|
|
|
|
"super-frpc/database"
|
|
"super-frpc/global"
|
|
"super-frpc/postLog"
|
|
"super-frpc/session"
|
|
)
|
|
|
|
// Response is the JSON envelope written by SendErrorResponse and
// SendSuccessResponse. Field order determines the key order in the output.
type Response struct {
	// Success reports whether the request succeeded; it is always serialized.
	Success bool `json:"success"`
	// Message is a human-readable status text; omitted from the JSON when empty.
	Message string `json:"message,omitempty"`
	// Data is an optional payload; omitted from the JSON when nil.
	Data interface{} `json:"data,omitempty"`
}
|
|
|
|
// SendErrorResponse sends an error response with the specified status code and message.
|
|
// If data is provided, it is included in the response.
|
|
func SendErrorResponse(w http.ResponseWriter, statusCode int, message string, data ...interface{}) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(statusCode)
|
|
resp := Response{
|
|
Success: false,
|
|
Message: message,
|
|
}
|
|
if len(data) > 0 {
|
|
resp.Data = data[0]
|
|
}
|
|
jsonResp, err := json.Marshal(resp)
|
|
if err != nil {
|
|
postLog.Error(fmt.Sprintf("Failed to marshal error response: %v", err))
|
|
return
|
|
}
|
|
w.Write(jsonResp)
|
|
}
|
|
|
|
// SendSuccessResponse sends a success response with the specified message and data.
|
|
// If data is provided, it is included in the response.
|
|
func SendSuccessResponse(w http.ResponseWriter, message string, data interface{}) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
resp := Response{
|
|
Success: true,
|
|
Message: message,
|
|
Data: data,
|
|
}
|
|
jsonResp, err := json.Marshal(resp)
|
|
if err != nil {
|
|
postLog.Error(fmt.Sprintf("Failed to marshal success response: %v", err))
|
|
return
|
|
}
|
|
w.Write(jsonResp)
|
|
}
|
|
|
|
// Auth authenticates the request.
|
|
// It checks the request method, timestamp, and token.
|
|
// If any of these checks fail, it returns an error.
|
|
// If no error occurs, it returns the user ID.
|
|
func Auth(w http.ResponseWriter, r *http.Request, targetMethod string, allowedUserLevels ...string) (int, error) {
|
|
if r.Method != targetMethod {
|
|
return 0, fmt.Errorf("Method not allowed: %s", targetMethod)
|
|
}
|
|
|
|
if !global.Is.Debug && !session.ValidateTimeStamp(r.Header, global.Is.Debug) {
|
|
return 0, fmt.Errorf("Invalid or missing X-Timestamp in header")
|
|
}
|
|
|
|
userID, err := session.ExtractUserIDFromToken(r.Header.Get("X-Token"))
|
|
if err != nil {
|
|
return 0, fmt.Errorf("Invalid token format: %w", err)
|
|
}
|
|
|
|
if err := session.ValidateToken(userID, r.Header.Get("X-Token")); err != nil {
|
|
return 0, fmt.Errorf("Token validation failed: %w", err)
|
|
}
|
|
|
|
if len(allowedUserLevels) > 0 {
|
|
currentUser, err := database.DBQuerySpecificUser(userID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("Failed to query user: %w", err)
|
|
}
|
|
allowed := false
|
|
for _, level := range allowedUserLevels {
|
|
if currentUser.Type == level {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
return 0, fmt.Errorf("User level not allowed: required one of %v, got %s", allowedUserLevels, currentUser.Type)
|
|
}
|
|
}
|
|
|
|
return userID, nil
|
|
}
|
|
|
|
func GetUserType(userID int) (string, error) {
|
|
user, err := database.GetUserByID(userID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("Failed to get user type: %w", err)
|
|
}
|
|
return user.Type, nil
|
|
}
|
|
|
|
// GetClientIP returns a best-effort client address for the request, intended
// for logging.
//
// If an X-Forwarded-For header is present, its first (left-most)
// comma-separated entry is returned, with surrounding whitespace trimmed.
// Conventionally, each proxy appends the address it received the request
// from, so the first entry is the originating client. If the header is absent
// or its first entry is empty, r.RemoteAddr is returned unchanged; for a
// net/http server this is normally "host:port".
//
// SECURITY NOTE(review): X-Forwarded-For is supplied by the client or by
// intermediaries and is trivially spoofable. Only trust it behind a proxy you
// control, and do not use this value for access-control decisions.
func GetClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			first = forwarded[:comma]
		}
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}
	return r.RemoteAddr
}
|
|
|
|
func LogRequest(r *http.Request, userID int) {
|
|
postLog.Info(fmt.Sprintf("[%s] %s %s - UserID: %d - IP: %s",
|
|
time.Now().Format("2006-01-02 15:04:05"),
|
|
r.Method,
|
|
r.URL.Path,
|
|
userID,
|
|
GetClientIP(r),
|
|
))
|
|
}
|
|
|
|
// IntToString returns the base-10 text form of i (for example, -42 becomes "-42").
func IntToString(i int) string {
	return strconv.FormatInt(int64(i), 10)
}
|
|
|
|
// GetTextMiddle returns the substring of text that lies between the first
// occurrence of left and the first occurrence of right after it.
//
// An empty left means "from the start of text"; an empty right means "to the
// end of text". It returns "" if a non-empty delimiter is not found.
func GetTextMiddle(text, left, right string) string {
	rest := text
	if left != "" {
		pos := strings.Index(rest, left)
		if pos < 0 {
			return ""
		}
		rest = rest[pos+len(left):]
	}

	if right == "" {
		return rest
	}
	if end := strings.Index(rest, right); end >= 0 {
		return rest[:end]
	}
	return ""
}
|
|
|
|
func GetCmdType(source string) string {
|
|
return GetTextMiddle(source, "[", "]")
|
|
}
|
|
|
|
func GetCmdParams(source string, param string) string {
|
|
return GetTextMiddle(source, "<"+param+">", "</"+param+">")
|
|
}
|
|
|
|
// IsFileExist reports whether filePath refers to something (a file or a
// directory) that os.Stat can reach. Symlinks are followed, so a dangling
// symlink reports false.
//
// Only a successful os.Stat counts as "exists". Testing !os.IsNotExist(err)
// would also report true for other stat failures that mean the path is not
// there or not usable. Examples are ENOTDIR for "file.txt/child", a symlink
// loop, or a name that is too long.
func IsFileExist(filePath string) bool {
	_, err := os.Stat(filePath)
	return err == nil
}
|
|
|
|
var pendingSystemdRestart string
|
|
|
|
func RestartHandler(w http.ResponseWriter, r *http.Request) {
|
|
userID, err := Auth(w, r, http.MethodGet, "superuser", "admin")
|
|
if err != nil {
|
|
SendErrorResponse(w, http.StatusUnauthorized, "invalid token or timestamp")
|
|
postLog.Warning(fmt.Sprintf("[RestartHandler] Auth failed: %v, userID: %d", err, userID))
|
|
return
|
|
}
|
|
err = RestartProcess()
|
|
if err != nil {
|
|
SendErrorResponse(w, http.StatusInternalServerError, fmt.Sprintf("failed to restart: %v", err))
|
|
postLog.Error(fmt.Sprintf("[RestartHandler] [%d] failed to restart: %v", userID, err))
|
|
return
|
|
}
|
|
SendSuccessResponse(w, "Restarting server...", nil)
|
|
postLog.Info(fmt.Sprintf("[RestartHandler] [%d] Restarting server", userID))
|
|
|
|
if pendingSystemdRestart != "" {
|
|
serviceName := pendingSystemdRestart
|
|
go func() {
|
|
time.Sleep(500 * time.Millisecond)
|
|
cmd := exec.Command("systemctl", "restart", serviceName)
|
|
if err := cmd.Run(); err != nil {
|
|
postLog.Error(fmt.Sprintf("[RestartHandler] [%d] failed to restart systemd service %s: %v", userID, serviceName, err))
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
func RestartProcess() error {
|
|
if getInitSystem() == "systemd" {
|
|
serviceName, err := getSystemdServiceName()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get systemd service name: %w", err)
|
|
}
|
|
pendingSystemdRestart = serviceName
|
|
return nil
|
|
}
|
|
|
|
executable, err := os.Executable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get executable path: %w", err)
|
|
}
|
|
cmd := exec.Command(executable, os.Args[1:]...)
|
|
cmd.Stdout = os.Stdout
|
|
cmd.Stderr = os.Stderr
|
|
if err := cmd.Start(); err != nil {
|
|
return fmt.Errorf("failed to start new process: %w", err)
|
|
}
|
|
os.Exit(0)
|
|
return nil
|
|
}
|
|
|
|
// GetCurrentPath returns the absolute directory containing the running
// executable, or "" if it cannot be determined.
//
// It first resolves os.Args[0] with exec.LookPath, which searches PATH for a
// bare command name and keeps a symlinked launcher's own directory. If that
// yields nothing, it falls back to os.Executable. That happens, for example,
// when the process was started with an argv[0] that is not a real path, with
// a different PATH, or with an empty argument list.
func GetCurrentPath() string {
	var filePath string
	if len(os.Args) > 0 {
		// The error is ignored on purpose: since Go 1.19 LookPath may return a
		// usable path together with an error matching exec.ErrDot. Only an
		// empty result means the lookup really failed.
		filePath, _ = exec.LookPath(os.Args[0])
	}
	if filePath == "" {
		exe, err := os.Executable()
		if err != nil {
			return ""
		}
		filePath = exe
	}

	absPath, err := filepath.Abs(filePath)
	if err != nil {
		return ""
	}
	return filepath.Dir(absPath)
}