diff --git a/config.go b/config.go index 6eab2d8..91e7e2d 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,5 @@ package main - import ( "Watchdog_Linux-systemd/postLog" "encoding/json" @@ -10,10 +9,6 @@ import ( var Config struct { DebugMode bool `json:"debugMode"` - Debug struct { - ListenAddr string `json:"listenAddr"` - ListenPort int `json:"listenPort"` - } `json:"debug"` } func loadConfig() { @@ -26,19 +21,9 @@ func loadConfig() { if err := decoder.Decode(&Config); err != nil { postLog.Fatal(fmt.Sprintf("Failed to decode config file: %v, err: %v", configFile, err)) } - + if Config.DebugMode { isDebug = true } - if Config.Debug.ListenAddr != "" { - DebugListenAddr = Config.Debug.ListenAddr - } else { - DebugListenAddr = "0.0.0.0" - } - if Config.Debug.ListenPort != 0 { - DebugListenPort = Config.Debug.ListenPort - } else { - DebugListenPort = 10080 - } } diff --git a/config.json b/config.json index 47df6d2..47a50f5 100644 --- a/config.json +++ b/config.json @@ -1,7 +1,3 @@ { - "debugMode": true, - "debug": { - "listenAddr": "0.0.0.0", - "listenPort": 10080 - } -} \ No newline at end of file + "debugMode": true +} diff --git a/main.go b/main.go index e8a89b3..1592a41 100644 --- a/main.go +++ b/main.go @@ -7,12 +7,6 @@ import ( "fmt" ) -var ( - listenAddr = "127.0.0.1" - listenPort = 10080 - Type = "tcp" -) - type SoftwareInfo struct { Name string Version string @@ -33,15 +27,10 @@ var softwareInfo SoftwareInfo = SoftwareInfo{ var isDebug bool -var DebugListenAddr string -var DebugListenPort int - func main() { loadConfig() if isDebug == true { postLog.SetDebugMode(true) - listenAddr = DebugListenAddr - listenPort = DebugListenPort } postLog.Info(fmt.Sprintf("%s %s (Build %d.%s) by %s", softwareInfo.Name, softwareInfo.Version, softwareInfo.BuildVer, softwareInfo.BuildType, softwareInfo.Developer)) @@ -50,7 +39,7 @@ func main() { // End of command handler go func() { - err := socket.BootSocket(Type, listenAddr, listenPort) + err := socket.BootSocket() if err != nil { postLog.Fatal(fmt.Sprintf("Failed to initialize socket server: %v", err)) } diff --git a/socket/server.go b/socket/server.go index b83ce35..f6af252 100644 --- a/socket/server.go +++ b/socket/server.go @@ -1,87 +1,120 @@ -package socket - -import ( - "Watchdog_Linux-systemd/postLog" - "bufio" - "fmt" - "net" - "strings" -) - -var ( - Conn net.Conn - CommandHandler func(string) error -) - -func BootSocket(networkType, listenAddr string, listenPort int) error { - listen, err := net.Listen(networkType, fmt.Sprintf("%s:%d", listenAddr, listenPort)) - if err != nil { - postLog.Fatal(fmt.Sprintf("[Socket] Failed to listen: %v, err: %v, %v", listenAddr, listenPort, err)) - return fmt.Errorf("failed to listen: %v, err: %v, %v", listenAddr, listenPort, err) - } - defer listen.Close() - - postLog.Info(fmt.Sprintf("Server is running on %s:%d", listenAddr, listenPort)) - - for { - Conn, err = listen.Accept() - if err != nil { - postLog.Error(fmt.Sprintf("Failed to accept: %v, err: %v", Conn, err)) - } - go handleRequest() - } -} - -func handleRequest() { - defer Conn.Close() - - reader := bufio.NewReader(Conn) - - for { - data, err := reader.ReadBytes('\n') - if err != nil { - return - } - - recvMsg := strings.TrimSpace(string(data)) - - responseMsg := "" - if len(recvMsg) != 0 { - postLog.Debug(fmt.Sprintf("Received message: %s", recvMsg)) - if recvMsg == "watchdogAgentConnectionTest" { - responseMsg = "success" - } else { - if CommandHandler != nil { - err := CommandHandler(recvMsg) - if err != nil { - responseMsg = fmt.Sprintf("error: %v", err) - } else { - responseMsg = "success" - } - } else { - responseMsg = "error: command handler not initialized" - } - } - } - - Conn.Write([]byte(responseMsg + "\n")) - } -} - -func SendMsg(msg string) error { - if Conn == nil { - return fmt.Errorf("connection is nil") - } - - data := []byte(msg + "\n") - n, err := Conn.Write(data) - if err != nil { - return fmt.Errorf("failed to write message: %v", err) - } - - if n != len(data) { - return fmt.Errorf("incomplete write: wrote %d bytes out of %d", n, len(data)) - } - - return nil -} +package socket + +import ( + "Watchdog_Linux-systemd/postLog" + "bufio" + "fmt" + "net" + "os" + "strings" + "sync" +) + +const SocketPath = "/tmp/super-frpc-watchdog.sock" + +var ( + Conn net.Conn + connMutex sync.Mutex + CommandHandler func(string) error +) + +func BootSocket() error { + if err := os.RemoveAll(SocketPath); err != nil { + return fmt.Errorf("failed to remove stale socket %s: %w", SocketPath, err) + } + + listen, err := net.Listen("unix", SocketPath) + if err != nil { + postLog.Fatal(fmt.Sprintf("[Socket] Failed to listen on %s: %v", SocketPath, err)) + return fmt.Errorf("failed to listen on %s: %w", SocketPath, err) + } + defer func() { + listen.Close() + os.Remove(SocketPath) + }() + + postLog.Info(fmt.Sprintf("Server is running on local socket %s", SocketPath)) + + for { + conn, err := listen.Accept() + if err != nil { + postLog.Error(fmt.Sprintf("Failed to accept local socket connection: %v", err)) + continue + } + connMutex.Lock() + if Conn != nil { + Conn.Close() + } + Conn = conn + connMutex.Unlock() + + go handleRequest(conn) + } +} + +func handleRequest(conn net.Conn) { + defer func() { + conn.Close() + connMutex.Lock() + if Conn == conn { + Conn = nil + } + connMutex.Unlock() + }() + + reader := bufio.NewReader(conn) + + for { + data, err := reader.ReadBytes('\n') + if err != nil { + return + } + + recvMsg := strings.TrimSpace(string(data)) + + responseMsg := "" + if len(recvMsg) != 0 { + postLog.Debug(fmt.Sprintf("Received message: %s", recvMsg)) + if recvMsg == "watchdogAgentConnectionTest" { + responseMsg = "success" + } else { + if CommandHandler != nil { + err := CommandHandler(recvMsg) + if err != nil { + responseMsg = fmt.Sprintf("error: %v", err) + } else { + responseMsg = "success" + } + } else { + responseMsg = "error: command handler not initialized" + } + } + } + + if _, err := conn.Write([]byte(responseMsg + "\n")); err != nil { + return + } + } +} + +func SendMsg(msg string) error { + connMutex.Lock() + conn := Conn + connMutex.Unlock() + + if conn == nil { + return fmt.Errorf("connection is nil") + } + + data := []byte(msg + "\n") + n, err := conn.Write(data) + if err != nil { + return fmt.Errorf("failed to write message: %v", err) + } + + if n != len(data) { + return fmt.Errorf("incomplete write: wrote %d bytes out of %d", n, len(data)) + } + + return nil +}