diff --git a/main.go b/main.go index 14fccea..5b4fa8e 100644 --- a/main.go +++ b/main.go @@ -22,8 +22,10 @@ package main import ( "fmt" + "io" "net/http" "os" + "path/filepath" "strings" "text/tabwriter" "time" @@ -54,6 +56,7 @@ func main() { version := pflag.BoolP("version", "v", false, "prints the version") dontDump := pflag.BoolP("dont-dump", "V", false, "don't be verbose and do not dump requests") dontServe := pflag.BoolP("dont-serve", "D", false, "don't serve any directy (ignores --directory)") + enableUpload := pflag.BoolP("enable-upload", "u", os.Getenv("LAMA_UPLOAD") == "true", "enable support for file uploads") pflag.Parse() if *version { @@ -72,6 +75,11 @@ func main() { if !*dontServe { handler.Delegate = http.FileServer(http.Dir(*dir)) fileStatement = fmt.Sprintf("files from %s ", *dir) + + if *enableUpload { + handler.Delegate = handleUploads(*dir, handler.Delegate) + fileStatement += "(with upload support) " + } } http.Handle("/", handler) @@ -145,3 +153,51 @@ func (h *debugHandler) logRequests() { fmt.Println() } } + +func handleUploads(dir string, handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ( + filename string + filebody io.Reader + ) + switch { + case r.Method == http.MethodPut: + defer r.Body.Close() + filename = r.URL.Path + filebody = r.Body + case r.Method == http.MethodPost && r.Header.Get("Content-Type") == "multipart/form-data": + defer r.Body.Close() + file, fileHeader, err := r.FormFile("file") + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer file.Close() + filename = fileHeader.Filename + filebody = file + default: + handler.ServeHTTP(w, r) + return + } + + tmpfile, err := os.CreateTemp(dir, "lama-upload-*") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + _, err = io.Copy(tmpfile, filebody) + if err != nil { + http.Error(w, fmt.Sprintf("failed to write file body: %v", err), http.StatusInternalServerError) + return + } + + err = os.Rename(tmpfile.Name(), filepath.Join(dir, filename)) + if err != nil { + http.Error(w, fmt.Sprintf("failed to rename file: %v", err), http.StatusInternalServerError) + return + } + }) +}