diff --git a/internal/config/config.go b/internal/config/config.go index 205ee02..9a704b2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,18 +3,29 @@ package config import ( "flag" "fmt" + "path/filepath" + + "github.com/charmbracelet/log" ) var ( - Addr string - Path string + Addr string + Path string + Debug bool ) func init() { port := flag.Int("port", 0, "port to listen on") flag.StringVar(&Path, "path", ".", "path to serve") + flag.BoolVar(&Debug, "debug", false, "enable debug logging") flag.Parse() Addr = fmt.Sprintf(":%d", *port) + + if p, err := filepath.Abs(Path); err == nil { + Path = p + } else { + log.Fatal(err) + } } diff --git a/internal/net/net.go b/internal/net/net.go index bbad970..5655b5b 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -47,91 +47,89 @@ func UploadHandler(w http.ResponseWriter, r *http.Request) { } defer conn.Close() - // Step 1: Receive and decode the file header - var header Header - _, message, err := conn.ReadMessage() - if err != nil { - log.Error("failed to read message", "err", err) - return - } - - if err := json.Unmarshal(message, &header); err != nil { - log.Error("failed to unmarshal header", "err", err) - return - } + for { + // Step 1: Receive and decode the file header + var header Header + _, message, err := conn.ReadMessage() + if err != nil { + log.Error("failed to read message", "err", err) + return + } - // Normalize and verify the file path - filePath, err := normalizePath(header.Name) - if err != nil { - log.Error("failed to normalize file path", "err", err) - conn.WriteMessage(websocket.TextMessage, []byte("Invalid file path")) - return - } + // Check if the message is "EOT" (end of transmission) + if string(message) == "EOT" { + log.Debug("end of transmission") + break + } - // Ensure the file path is within the server's working directory - if !isInWorkingDir(filePath) { - log.Error("file path outside working directory") - conn.WriteMessage(websocket.TextMessage, []byte("File path outside working directory")) - return - } + if err := json.Unmarshal(message, &header); err != nil { + log.Error("failed to unmarshal header", "err", err) + return + } - // Create the target directory if needed - if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { - log.Error("failed to create directory", "err", err) - conn.WriteMessage(websocket.TextMessage, []byte("Directory creation error")) - return - } + // Normalize and verify the file path + filePath := normalizePath(header.Name) - // Open the target file for writing - file, err := os.Create(filePath) - if err != nil { - log.Error("failed to create file", "err", err) - conn.WriteMessage(websocket.TextMessage, []byte("File creation error")) - return - } - defer file.Close() + // Ensure the file path is within the server's working directory + if !isInWorkingDir(filePath) { + log.Error("file path outside working directory") + conn.WriteMessage(websocket.TextMessage, []byte("File path outside working directory")) + return + } - // Send "READY" message to the client - conn.WriteMessage(websocket.TextMessage, []byte("READY")) + // Create the target directory if needed + if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { + log.Error("failed to create directory", "err", err) + conn.WriteMessage(websocket.TextMessage, []byte("Directory creation error")) + return + } - // Step 2: Receive and write file chunks - for { - _, message, err := conn.ReadMessage() + // Open the target file for writing + file, err := os.Create(filePath) if err != nil { - log.Error("failed to read message", "err", err) + log.Error("failed to create file", "err", err) + conn.WriteMessage(websocket.TextMessage, []byte("File creation error")) return } + defer file.Close() - // Check for EOF - if string(message) == "EOF" { - log.Debug("end of file") - break + // Send "READY" message to the client + conn.WriteMessage(websocket.TextMessage, []byte("READY")) + + // Step 2: Receive and write file chunks + for { + _, message, err := conn.ReadMessage() + if err != nil { + log.Error("failed to read message", "err", err) + return + } + + // Check for EOF + if string(message) == "EOF" { + log.Debug("end of file") + break + } + + // Write the chunk to the file + if _, err := file.Write(message); err != nil { + log.Error("failed to write file chunk", "err", err) + return + } } - // Write the chunk to the file - if _, err := file.Write(message); err != nil { - log.Error("failed to write file chunk", "err", err) - return + // Optionally, set the file's last modified time + modTime := time.Unix(header.LastModified, 0) + if err := os.Chtimes(filePath, modTime, modTime); err != nil { + log.Error("failed to set last modified time", "err", err) } - } - // Optionally, set the file's last modified time - modTime := time.Unix(header.LastModified, 0) - if err := os.Chtimes(filePath, modTime, modTime); err != nil { - log.Error("failed to set last modified time", "err", err) + log.Info("file copy successful", "file", filePath) } - - log.Info("file copy successful") } // normalizePath cleans and returns the absolute path of the file. -func normalizePath(name string) (string, error) { - // Clean and make the path absolute - absPath, err := filepath.Abs(filepath.Clean(name)) - if err != nil { - return "", err - } - return absPath, nil +func normalizePath(name string) string { + return filepath.Join(config.Path, filepath.Clean(name)) } // isInWorkingDir checks if a path is within the server's current working directory. @@ -188,7 +186,7 @@ func StartServer() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { - log.Fatalf("Server forced to shutdown: %v", err) + log.Warn("Server forced to shutdown", "err", err) } }() diff --git a/internal/public/index.js b/internal/public/index.js index 2abf4e6..e3e9929 100644 --- a/internal/public/index.js +++ b/internal/public/index.js @@ -37,7 +37,7 @@ const bandwidthCalculator = { /** handles errors sent from the Go server */ class ServerError extends Error { constructor(message) { - super("unexpected message from server"); + super("unexpected message from server: " + message); this._serverMessage = message; } diff --git a/main.go b/main.go index 3f5a545..481016b 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,9 @@ package main import ( - "log" "os" + "github.com/charmbracelet/log" "github.com/tsukinoko-kun/portal/internal/config" "github.com/tsukinoko-kun/portal/internal/net" ) @@ -13,6 +13,10 @@ func main() { log.Fatal(err) } + if config.Debug { + log.SetLevel(log.DebugLevel) + } + if err := net.StartServer(); err != nil { log.Fatal(err) }