Skip to content

Commit

Permalink
Prevent replay attacks on old request files from other users.
Browse files Browse the repository at this point in the history
This uses a lock across all handlers to ensure that the same request file isn't
being used multiple times at any given period - we then delete the request file
afterwards to guarantee that it can't be used again.
  • Loading branch information
LTLA committed Apr 13, 2024
1 parent 0a75509 commit cda7e79
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 8 deletions.
86 changes: 78 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strconv"
"io/fs"
"syscall"
"sync"
)

func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}, path string) {
Expand All @@ -33,20 +34,19 @@ func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}, path str
}
}

func dumpErrorResponse(w http.ResponseWriter, status int, message string, path string) {
log.Printf("failed to process %q; %s\n", path, message)
dumpJsonResponse(w, status, map[string]interface{}{ "status": "ERROR", "reason": message }, path)
}

func dumpHttpErrorResponse(w http.ResponseWriter, err error, path string) {
status_code := http.StatusInternalServerError
var http_err *httpError
if errors.As(err, &http_err) {
status_code = http_err.Status
}
dumpErrorResponse(w, status_code, err.Error(), path)
message := err.Error()
log.Printf("failed to process %q; %s\n", path, message)
dumpJsonResponse(w, status_code, map[string]interface{}{ "status": "ERROR", "reason": message }, path)
}

/***************************************************/

func checkRequestFile(path, staging string) (string, error) {
if !strings.HasPrefix(path, "request-") {
return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\""))
Expand Down Expand Up @@ -81,6 +81,60 @@ func checkRequestFile(path, staging string) (string, error) {
return reqpath, nil
}

/***************************************************/

// This tracks the requests that are currently being processed, to prevent the
// same request being processed multiple times at the same time. We use a
// multi-pool approach to improve parallelism across requests.
type activeRegistry struct {
NumPools int
Locks []sync.Mutex
Active []map[string]bool
}

func newActiveRegistry(num_pools int) *activeRegistry {
return &activeRegistry {
NumPools: num_pools,
Locks: make([]sync.Mutex, num_pools),
Active: make([]map[string]bool, num_pools),
}
}

func (a *activeRegistry) choosePool(path string) int {
sum := 0
for _, r := range path {
sum += int(r)
}
return sum % a.NumPools
}

func (a *activeRegistry) Add(path string) bool {
i := a.choosePool(path)
a.Locks[i].Lock()
defer a.Locks[i].Unlock()

if a.Active[i] == nil {
a.Active[i] = map[string]bool{}
} else {
_, ok := a.Active[i][path]
if ok {
return false
}
}

a.Active[i][path] = true
return true
}

func (a *activeRegistry) Remove(path string) {
i := a.choosePool(path)
a.Locks[i].Lock()
defer a.Locks[i].Unlock()
delete(a.Active[i], path)
}

/***************************************************/

func main() {
spath := flag.String("staging", "", "Path to the staging directory.")
rpath := flag.String("registry", "", "Path to the registry.")
Expand All @@ -107,6 +161,8 @@ func main() {
}
}

actreg := newActiveRegistry(11)

// Creating an endpoint to trigger jobs.
http.HandleFunc("POST /new/{path}", func(w http.ResponseWriter, r *http.Request) {
path := r.PathValue("path")
Expand All @@ -118,6 +174,11 @@ func main() {
return
}

if !actreg.Add(path) {
dumpHttpErrorResponse(w, newHttpError(http.StatusBadRequest, errors.New("path is already being processed")), path)
return
}

var reportable_err error
payload := map[string]interface{}{}
reqtype := strings.TrimPrefix(path, "request-")
Expand Down Expand Up @@ -160,8 +221,17 @@ func main() {
} else if strings.HasPrefix(reqtype, "health_check-") { // TO-BE-DEPRECATED, see /check below.
reportable_err = nil
} else {
dumpErrorResponse(w, http.StatusBadRequest, "invalid request type", reqpath)
return
reportable_err = newHttpError(http.StatusBadRequest, errors.New("invalid request type"))
}

// Purge the request file once it's processed, to reduce the potential
// for replay attacks. For safety's sake, we only remove it from the
// registry if the request file was properly deleted.
err = os.Remove(reqpath)
if err != nil {
log.Printf("failed to purge the request file at %q; %v", path, err)
} else {
actreg.Remove(path)
}

if reportable_err == nil {
Expand Down
26 changes: 26 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,29 @@ func TestCheckRequestFile(t *testing.T) {
}
})
}

func TestActiveRegistry(t *testing.T) {
a := newActiveRegistry(3)

path := "adasdasdasd"
ok := a.Add(path)
if !ok {
t.Fatal("expected a successful addition")
}

ok = a.Add(path)
if ok {
t.Fatal("expected a failed addition")
}

a.Remove(path)
ok = a.Add(path)
if !ok {
t.Fatal("expected a successful addition again")
}

ok = a.Add("xyxyxyxyxyx")
if !ok {
t.Fatal("expected a successful addition again")
}
}

0 comments on commit cda7e79

Please sign in to comment.