diff --git a/lib/denylist/http.go b/lib/denylist/http.go index 39599500..9086e730 100644 --- a/lib/denylist/http.go +++ b/lib/denylist/http.go @@ -3,10 +3,11 @@ package denylist import ( "encoding/json" "net/http" + "sync" ) // CollectionEndpoint serves the endpoints for the whole Denylist at /denylist -func CollectionEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *http.Request) { +func CollectionEndpoint(denylist *sync.Map) func(http.ResponseWriter, *http.Request) { return func(response http.ResponseWriter, request *http.Request) { switch request.Method { case "GET": @@ -18,7 +19,7 @@ func CollectionEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *ht } // SingleEndpoint serves the endpoints for particular Denylist entries at /denylist/... -func SingleEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *http.Request) { +func SingleEndpoint(denylist *sync.Map) func(http.ResponseWriter, *http.Request) { return func(response http.ResponseWriter, request *http.Request) { switch request.Method { case "GET": @@ -34,14 +35,13 @@ func SingleEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *http.R } // GET /denylist -func listDenylistKeys(response http.ResponseWriter, denylist *map[string]bool) { - keys := make([]string, len(*denylist)) +func listDenylistKeys(response http.ResponseWriter, denylist *sync.Map) { + keys := []interface{}{} - i := 0 - for k := range *denylist { - keys[i] = k - i++ - } + denylist.Range(func(key interface{}, value interface{}) bool { + keys = append(keys, key) + return true + }) response.Header().Set("Content-Type", "application/json") response.WriteHeader(http.StatusOK) @@ -53,9 +53,9 @@ func listDenylistKeys(response http.ResponseWriter, denylist *map[string]bool) { } // GET /denylist/... -func getDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *map[string]bool) { +func getDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) { id := request.URL.Path - _, exists := (*denylist)[id] + _, exists := denylist.Load(id) if !exists { http.Error(response, "denylist entry not found with that id", http.StatusNotFound) return @@ -71,27 +71,27 @@ func getDenylistEntry(response http.ResponseWriter, request *http.Request, denyl } // PUT /denylist/... -func createDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *map[string]bool) { +func createDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) { id := request.URL.Path - _, exists := (*denylist)[id] + _, exists := denylist.Load(id) if exists { response.WriteHeader(http.StatusNoContent) return } - (*denylist)[id] = true + denylist.Store(id, true) response.WriteHeader(http.StatusCreated) } // DELETE /denylist/... -func deleteDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *map[string]bool) { +func deleteDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) { id := request.URL.Path - _, exists := (*denylist)[id] + _, exists := denylist.Load(id) if !exists { http.Error(response, "denylist entry not found with that id", http.StatusNotFound) return } - delete(*denylist, id) + denylist.Delete(id) response.WriteHeader(http.StatusNoContent) } diff --git a/lib/oplog/tail.go b/lib/oplog/tail.go index 8a91ebcb..5eb799a1 100644 --- a/lib/oplog/tail.go +++ b/lib/oplog/tail.go @@ -7,6 +7,7 @@ import ( "context" "errors" "strings" + "sync" "time" "github.com/tulip/oplogtoredis/lib/config" @@ -29,7 +30,7 @@ type Tailer struct { RedisClients []redis.UniversalClient RedisPrefix string MaxCatchUp time.Duration - Denylist *map[string]bool + Denylist *sync.Map } // Raw oplog entry from Mongo diff --git a/lib/oplog/tail_test.go b/lib/oplog/tail_test.go index 4d26e92c..81fcd0b4 100644 --- a/lib/oplog/tail_test.go +++ b/lib/oplog/tail_test.go @@ -3,6 +3,7 @@ package oplog import ( "errors" "strconv" + "sync" "testing" "time" @@ -83,7 +84,7 @@ func TestGetStartTime(t *testing.T) { RedisClients: redisClient, RedisPrefix: "someprefix.", MaxCatchUp: maxCatchUp, - Denylist: &map[string]bool{}, + Denylist: &sync.Map{}, } actualResult := tailer.getStartTime(func() (*primitive.Timestamp, error) { @@ -286,7 +287,7 @@ func TestParseRawOplogEntry(t *testing.T) { for testName, test := range tests { t.Run(testName, func(t *testing.T) { - got := (&Tailer{Denylist: &map[string]bool{}}).parseRawOplogEntry(test.in, nil) + got := (&Tailer{Denylist: &sync.Map{}}).parseRawOplogEntry(test.in, nil) if diff := pretty.Compare(got, test.want); diff != "" { t.Errorf("Got incorrect result (-got +want)\n%s", diff) diff --git a/main.go b/main.go index 28c31343..1aab503d 100644 --- a/main.go +++ b/main.go @@ -62,7 +62,7 @@ func main() { bufferSize := 10000 waitGroup := sync.WaitGroup{} - denylist := map[string]bool{} + denylist := sync.Map{} for i := 0; i < config.WriteParallelism(); i++ { redisClients, err := createRedisClients() @@ -249,7 +249,7 @@ func createRedisClients() ([]redis.UniversalClient, error) { return ret, nil } -func makeHTTPServer(aggregatedClients [][]redis.UniversalClient, mongo *mongo.Client, denylistMap *map[string]bool) *http.Server { +func makeHTTPServer(aggregatedClients [][]redis.UniversalClient, mongo *mongo.Client, denylistMap *sync.Map) *http.Server { mux := http.NewServeMux() mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {