Skip to content

Commit

Permalink
fix: handle directories
Browse files Browse the repository at this point in the history
  • Loading branch information
tsukinoko-kun committed Nov 9, 2024
1 parent 5014461 commit 0e0edf1
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 72 deletions.
15 changes: 13 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,29 @@ package config
import (
"flag"
"fmt"
"path/filepath"

"github.com/charmbracelet/log"
)

var (
Addr string
Path string
Addr string
Path string
Debug bool
)

func init() {
port := flag.Int("port", 0, "port to listen on")
flag.StringVar(&Path, "path", ".", "path to serve")
flag.BoolVar(&Debug, "debug", false, "enable debug logging")

flag.Parse()

Addr = fmt.Sprintf(":%d", *port)

if p, err := filepath.Abs(Path); err == nil {
Path = p
} else {
log.Fatal(err)
}
}
134 changes: 66 additions & 68 deletions internal/net/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,91 +47,89 @@ func UploadHandler(w http.ResponseWriter, r *http.Request) {
}
defer conn.Close()

// Step 1: Receive and decode the file header
var header Header
_, message, err := conn.ReadMessage()
if err != nil {
log.Error("failed to read message", "err", err)
return
}

if err := json.Unmarshal(message, &header); err != nil {
log.Error("failed to unmarshal header", "err", err)
return
}
for {
// Step 1: Receive and decode the file header
var header Header
_, message, err := conn.ReadMessage()
if err != nil {
log.Error("failed to read message", "err", err)
return
}

// Normalize and verify the file path
filePath, err := normalizePath(header.Name)
if err != nil {
log.Error("failed to normalize file path", "err", err)
conn.WriteMessage(websocket.TextMessage, []byte("Invalid file path"))
return
}
// Check if the message is "EOT" (end of transmission)
if string(message) == "EOT" {
log.Debug("end of transmission")
break
}

// Ensure the file path is within the server's working directory
if !isInWorkingDir(filePath) {
log.Error("file path outside working directory")
conn.WriteMessage(websocket.TextMessage, []byte("File path outside working directory"))
return
}
if err := json.Unmarshal(message, &header); err != nil {
log.Error("failed to unmarshal header", "err", err)
return
}

// Create the target directory if needed
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
log.Error("failed to create directory", "err", err)
conn.WriteMessage(websocket.TextMessage, []byte("Directory creation error"))
return
}
// Normalize and verify the file path
filePath := normalizePath(header.Name)

// Open the target file for writing
file, err := os.Create(filePath)
if err != nil {
log.Error("failed to create file", "err", err)
conn.WriteMessage(websocket.TextMessage, []byte("File creation error"))
return
}
defer file.Close()
// Ensure the file path is within the server's working directory
if !isInWorkingDir(filePath) {
log.Error("file path outside working directory")
conn.WriteMessage(websocket.TextMessage, []byte("File path outside working directory"))
return
}

// Send "READY" message to the client
conn.WriteMessage(websocket.TextMessage, []byte("READY"))
// Create the target directory if needed
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
log.Error("failed to create directory", "err", err)
conn.WriteMessage(websocket.TextMessage, []byte("Directory creation error"))
return
}

// Step 2: Receive and write file chunks
for {
_, message, err := conn.ReadMessage()
// Open the target file for writing
file, err := os.Create(filePath)
if err != nil {
log.Error("failed to read message", "err", err)
log.Error("failed to create file", "err", err)
conn.WriteMessage(websocket.TextMessage, []byte("File creation error"))
return
}
defer file.Close()

// Check for EOF
if string(message) == "EOF" {
log.Debug("end of file")
break
// Send "READY" message to the client
conn.WriteMessage(websocket.TextMessage, []byte("READY"))

// Step 2: Receive and write file chunks
for {
_, message, err := conn.ReadMessage()
if err != nil {
log.Error("failed to read message", "err", err)
return
}

// Check for EOF
if string(message) == "EOF" {
log.Debug("end of file")
break
}

// Write the chunk to the file
if _, err := file.Write(message); err != nil {
log.Error("failed to write file chunk", "err", err)
return
}
}

// Write the chunk to the file
if _, err := file.Write(message); err != nil {
log.Error("failed to write file chunk", "err", err)
return
// Optionally, set the file's last modified time
modTime := time.Unix(header.LastModified, 0)
if err := os.Chtimes(filePath, modTime, modTime); err != nil {
log.Error("failed to set last modified time", "err", err)
}
}

// Optionally, set the file's last modified time
modTime := time.Unix(header.LastModified, 0)
if err := os.Chtimes(filePath, modTime, modTime); err != nil {
log.Error("failed to set last modified time", "err", err)
log.Info("file copy successful", "file", filePath)
}

log.Info("file copy successful")
}

// normalizePath cleans and returns the absolute path of the file.
func normalizePath(name string) (string, error) {
// Clean and make the path absolute
absPath, err := filepath.Abs(filepath.Clean(name))
if err != nil {
return "", err
}
return absPath, nil
func normalizePath(name string) string {
return filepath.Join(config.Path, filepath.Clean(name))
}

// isInWorkingDir checks if a path is within the server's current working directory.
Expand Down Expand Up @@ -188,7 +186,7 @@ func StartServer() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
log.Warn("Server forced to shutdown", "err", err)
}
}()

Expand Down
2 changes: 1 addition & 1 deletion internal/public/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ const bandwidthCalculator = {
/** handles errors sent from the Go server */
class ServerError extends Error {
constructor(message) {
super("unexpected message from server");
super("unexpected message from server: " + message);
this._serverMessage = message;
}

Expand Down
6 changes: 5 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package main

import (
"log"
"os"

"github.com/charmbracelet/log"
"github.com/tsukinoko-kun/portal/internal/config"
"github.com/tsukinoko-kun/portal/internal/net"
)
Expand All @@ -13,6 +13,10 @@ func main() {
log.Fatal(err)
}

if config.Debug {
log.SetLevel(log.DebugLevel)
}

if err := net.StartServer(); err != nil {
log.Fatal(err)
}
Expand Down

0 comments on commit 0e0edf1

Please sign in to comment.