From 2427ae9a39e3924d0986e6c91efb9fbe06ec65fd Mon Sep 17 00:00:00 2001 From: LTLA Date: Sat, 25 Jan 2025 11:56:48 -0800 Subject: [PATCH] Streamlined the locks for the active request registry. We eliminate the pooled locks for simplicity. The period of lock contention is so short - we're just adding/removing paths from an in-memory map - that it's not worth the complexity to handle a lock pool. The process of removing each request path from the registry is now in a goroutine that is spawned when the path is added to the registry. This guarantees that the expiry happens in a timely manner. We no longer remove the file at the request path as (i) the active registry protects against replays, (ii) we'd like to use it as a log of recent requests and (iii) we already have a job that purges old files anyway. The creation of a new activeRequestRegistry now requires that a staging directory be supplied; no point splitting the creation and prefilling. Similarly, the expiry time is stored in the registry object for convenience. --- main.go | 19 +-------- request.go | 76 ++++++++++++++---------------------- request_test.go | 100 +++++++++++++++++++++++++++--------------------- 3 files changed, 88 insertions(+), 107 deletions(-) diff --git a/main.go b/main.go index 7d85258..1b94a9f 100644 --- a/main.go +++ b/main.go @@ -70,9 +70,8 @@ func main() { } } - actreg := newActiveRequestRegistry(11) - const request_expiry = time.Minute - err := prefillActiveRequestRegistry(actreg, staging, request_expiry) + request_expiry := time.Minute + actreg, err := newActiveRequestRegistry(staging, request_expiry) if err != nil { log.Fatalf("failed to prefill active request registry; %v", err) } @@ -143,20 +142,6 @@ func main() { reportable_err = newHttpError(http.StatusBadRequest, errors.New("invalid request type")) } - // Purge the request file once it's processed, to reduce the potential - // for replay attacks. For safety's sake, we only remove it from the - // registry if the request file was properly deleted or it expired. - err = os.Remove(reqpath) - if err != nil { - log.Printf("failed to purge the request file at %q; %v", path, err) - go func() { - time.Sleep(request_expiry) - actreg.Remove(path) - }() - } else { - actreg.Remove(path) - } - if reportable_err == nil { payload["status"] = "SUCCESS" dumpJsonResponse(w, http.StatusOK, &payload, path) diff --git a/request.go b/request.go index 8f02a95..ebbe3da 100644 --- a/request.go +++ b/request.go @@ -13,74 +13,56 @@ import ( "syscall" ) -func chooseLockPool(path string, num_pools int) int { - sum := 0 - for _, r := range path { - sum += int(r) - } - return sum % num_pools -} // This tracks the requests that are currently being processed, to prevent the -// same request being processed multiple times at the same time. We use a -// multi-pool approach to improve parallelism across requests. +// same request being processed multiple times at the same time. type activeRequestRegistry struct { - NumPools int - Locks []sync.Mutex - Active []map[string]bool + Lock sync.Mutex + Active map[string]bool + Expiry time.Duration } -func newActiveRequestRegistry(num_pools int) *activeRequestRegistry { - return &activeRequestRegistry { - NumPools: num_pools, - Locks: make([]sync.Mutex, num_pools), - Active: make([]map[string]bool, num_pools), +func newActiveRequestRegistry(staging string, expiry time.Duration) (*activeRequestRegistry, error) { + output := &activeRequestRegistry { + Active: map[string]bool{}, + Expiry: expiry, } -} -func prefillActiveRequestRegistry(a *activeRequestRegistry, staging string, expiry time.Duration) error { // Prefilling the registry ensures that a user can't replay requests after a restart of the service. entries, err := os.ReadDir(staging) if err != nil { - return fmt.Errorf("failed to list existing request files in '%s'", staging) + return nil, fmt.Errorf("failed to list existing request files in '%s'", staging) } - // This is only necessary until the expiry time is exceeded, after which we can evict those entries. // Technically we only need to do this for files that weren't already expired, but this doesn't hurt. for _, e := range entries { path := e.Name() - a.Add(path) - go func(p string) { - time.Sleep(expiry) - a.Remove(p) - }(path) + output.Add(path) } - return nil + return output, nil } func (a *activeRequestRegistry) Add(path string) bool { - i := chooseLockPool(path, a.NumPools) - a.Locks[i].Lock() - defer a.Locks[i].Unlock() - - if a.Active[i] == nil { - a.Active[i] = map[string]bool{} - } else { - _, ok := a.Active[i][path] - if ok { - return false - } + a.Lock.Lock() + defer a.Lock.Unlock() + + _, ok := a.Active[path] + if ok { + return false } - - a.Active[i][path] = true - return true -} -func (a *activeRequestRegistry) Remove(path string) { - i := chooseLockPool(path, a.NumPools) - a.Locks[i].Lock() - defer a.Locks[i].Unlock() - delete(a.Active[i], path) + a.Active[path] = true + + // Once the request expires, we no longer need to protect against replay attacks, + // so we can delete it from the registry. + go func() { + time.Sleep(a.Expiry) + a.Lock.Lock() + defer a.Lock.Unlock() + delete(a.Active, path) + }() + + return true } func checkRequestFile(path, staging string, expiry time.Duration) (string, error) { diff --git a/request_test.go b/request_test.go index 77a75f5..f57be60 100644 --- a/request_test.go +++ b/request_test.go @@ -101,61 +101,75 @@ func TestCheckRequestFile(t *testing.T) { } func TestActiveRequestRegistry(t *testing.T) { - a := newActiveRequestRegistry(3) + t.Run("basic", func(t *testing.T) { + staging, err := os.MkdirTemp("", "") + if err != nil { + t.Fatal(err) + } - path := "adasdasdasd" - ok := a.Add(path) - if !ok { - t.Fatal("expected a successful addition") - } + a, err := newActiveRequestRegistry(staging, time.Millisecond * 200) - ok = a.Add(path) - if ok { - t.Fatal("expected a failed addition") - } + path := "adasdasdasd" + ok := a.Add(path) + if !ok { + t.Fatal("expected a successful addition") + } - a.Remove(path) - ok = a.Add(path) - if !ok { - t.Fatal("expected a successful addition again") - } + ok = a.Add(path) + if ok { + t.Fatal("expected a failed addition") + } - ok = a.Add("xyxyxyxyxyx") - if !ok { - t.Fatal("expected a successful addition again") - } -} + time.Sleep(time.Millisecond * 500) + ok = a.Add(path) + if !ok { + t.Fatal("expected a successful addition again") + } -func TestPrefillActiveRequestRegistry(t *testing.T) { - staging, err := os.MkdirTemp("", "") - if err != nil { - t.Fatal(err) - } + ok = a.Add("xyxyxyxyxyx") + if !ok { + t.Fatal("expected a successful addition again") + } + }) - names := []string{ "foo", "bar", "whee" } - for _, f := range names { - err = os.WriteFile(filepath.Join(staging, f), []byte{}, 0644) + t.Run("preloaded", func(t *testing.T) { + staging, err := os.MkdirTemp("", "") if err != nil { t.Fatal(err) } - } - a := newActiveRequestRegistry(3) - err = prefillActiveRequestRegistry(a, staging, time.Millisecond * 100) - if err != nil { - t.Fatal(err) - } + names := []string{ "foo", "bar", "whee" } + for _, f := range names { + err = os.WriteFile(filepath.Join(staging, f), []byte{}, 0644) + if err != nil { + t.Fatal(err) + } + } + + a, err := newActiveRequestRegistry(staging, time.Millisecond * 200) + if err != nil { + t.Fatal(err) + } - for _, f := range names { - if a.Add(f) { - t.Fatalf("%s should already be present in the registry", f) + for _, f := range names { + if a.Add(f) { + t.Fatalf("%s should already be present in the registry", f) + } } - } - time.Sleep(time.Millisecond * 200) - for _, f := range names { - if !a.Add(f) { - t.Fatalf("%s should have been removed from the registry", f) + // Adding some more names. + if !a.Add("stuff") { + t.Fatal("failed to add some new names") } - } + if a.Add("stuff") { + t.Fatal("should have failed to add a duplicate name") + } + + time.Sleep(time.Millisecond * 500) + for _, f := range names { + if !a.Add(f) { + t.Fatalf("%s should have been removed from the registry", f) + } + } + }) }