From cda7e795f576445637e659bbf5393be310c79176 Mon Sep 17 00:00:00 2001 From: LTLA Date: Sat, 13 Apr 2024 12:22:57 -0700 Subject: [PATCH] Prevent replay attacks on old request files from other users. This uses a lock across all handlers to ensure that the same request file isn't being used multiple times at any given period - we then delete the request file afterwards to guarantee that it can't be used again. --- main.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++----- main_test.go | 26 ++++++++++++++++ 2 files changed, 104 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index e52e993..78b898c 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "strconv" "io/fs" "syscall" + "sync" ) func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}, path string) { @@ -33,20 +34,19 @@ func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}, path str } } -func dumpErrorResponse(w http.ResponseWriter, status int, message string, path string) { - log.Printf("failed to process %q; %s\n", path, message) - dumpJsonResponse(w, status, map[string]interface{}{ "status": "ERROR", "reason": message }, path) -} - func dumpHttpErrorResponse(w http.ResponseWriter, err error, path string) { status_code := http.StatusInternalServerError var http_err *httpError if errors.As(err, &http_err) { status_code = http_err.Status } - dumpErrorResponse(w, status_code, err.Error(), path) + message := err.Error() + log.Printf("failed to process %q; %s\n", path, message) + dumpJsonResponse(w, status_code, map[string]interface{}{ "status": "ERROR", "reason": message }, path) } +/***************************************************/ + func checkRequestFile(path, staging string) (string, error) { if !strings.HasPrefix(path, "request-") { return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\"")) @@ -81,6 +81,60 @@ func checkRequestFile(path, staging string) (string, error) { return reqpath, nil } +/***************************************************/ + +// This tracks the requests that are currently being processed, to prevent the +// same request being processed multiple times at the same time. We use a +// multi-pool approach to improve parallelism across requests. +type activeRegistry struct { + NumPools int + Locks []sync.Mutex + Active []map[string]bool +} + +func newActiveRegistry(num_pools int) *activeRegistry { + return &activeRegistry { + NumPools: num_pools, + Locks: make([]sync.Mutex, num_pools), + Active: make([]map[string]bool, num_pools), + } +} + +func (a *activeRegistry) choosePool(path string) int { + sum := 0 + for _, r := range path { + sum += int(r) + } + return sum % a.NumPools +} + +func (a *activeRegistry) Add(path string) bool { + i := a.choosePool(path) + a.Locks[i].Lock() + defer a.Locks[i].Unlock() + + if a.Active[i] == nil { + a.Active[i] = map[string]bool{} + } else { + _, ok := a.Active[i][path] + if ok { + return false + } + } + + a.Active[i][path] = true + return true +} + +func (a *activeRegistry) Remove(path string) { + i := a.choosePool(path) + a.Locks[i].Lock() + defer a.Locks[i].Unlock() + delete(a.Active[i], path) +} + +/***************************************************/ + func main() { spath := flag.String("staging", "", "Path to the staging directory.") rpath := flag.String("registry", "", "Path to the registry.") @@ -107,6 +161,8 @@ func main() { } } + actreg := newActiveRegistry(11) + // Creating an endpoint to trigger jobs. http.HandleFunc("POST /new/{path}", func(w http.ResponseWriter, r *http.Request) { path := r.PathValue("path") @@ -118,6 +174,11 @@ func main() { return } + if !actreg.Add(path) { + dumpHttpErrorResponse(w, newHttpError(http.StatusBadRequest, errors.New("path is already being processed")), path) + return + } + var reportable_err error payload := map[string]interface{}{} reqtype := strings.TrimPrefix(path, "request-") @@ -160,8 +221,17 @@ func main() { } else if strings.HasPrefix(reqtype, "health_check-") { // TO-BE-DEPRECATED, see /check below. reportable_err = nil } else { - dumpErrorResponse(w, http.StatusBadRequest, "invalid request type", reqpath) - return + reportable_err = newHttpError(http.StatusBadRequest, errors.New("invalid request type")) + } + + // Purge the request file once it's processed, to reduce the potential + // for replay attacks. For safety's sake, we only remove it from the + // registry if the request file was properly deleted. + err = os.Remove(reqpath) + if err != nil { + log.Printf("failed to purge the request file at %q; %v", path, err) + } else { + actreg.Remove(path) } if reportable_err == nil { diff --git a/main_test.go b/main_test.go index 3dc10b7..bb2c8b6 100644 --- a/main_test.go +++ b/main_test.go @@ -85,3 +85,29 @@ func TestCheckRequestFile(t *testing.T) { } }) } + +func TestActiveRegistry(t *testing.T) { + a := newActiveRegistry(3) + + path := "adasdasdasd" + ok := a.Add(path) + if !ok { + t.Fatal("expected a successful addition") + } + + ok = a.Add(path) + if ok { + t.Fatal("expected a failed addition") + } + + a.Remove(path) + ok = a.Add(path) + if !ok { + t.Fatal("expected a successful addition again") + } + + ok = a.Add("xyxyxyxyxyx") + if !ok { + t.Fatal("expected a successful addition again") + } +}