Skip to content

Commit

Permalink
use sync map
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-goodisman committed Apr 25, 2024
1 parent 8e93c5f commit 9df196b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 22 deletions.
34 changes: 17 additions & 17 deletions lib/denylist/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package denylist
import (
"encoding/json"
"net/http"
"sync"
)

// CollectionEndpoint serves the endpoints for the whole Denylist at /denylist
func CollectionEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *http.Request) {
func CollectionEndpoint(denylist *sync.Map) func(http.ResponseWriter, *http.Request) {
return func(response http.ResponseWriter, request *http.Request) {
switch request.Method {
case "GET":
Expand All @@ -18,7 +19,7 @@ func CollectionEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *ht
}

// SingleEndpoint serves the endpoints for particular Denylist entries at /denylist/...
func SingleEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *http.Request) {
func SingleEndpoint(denylist *sync.Map) func(http.ResponseWriter, *http.Request) {
return func(response http.ResponseWriter, request *http.Request) {
switch request.Method {
case "GET":
Expand All @@ -34,14 +35,13 @@ func SingleEndpoint(denylist *map[string]bool) func(http.ResponseWriter, *http.R
}

// GET /denylist
func listDenylistKeys(response http.ResponseWriter, denylist *map[string]bool) {
keys := make([]string, len(*denylist))
func listDenylistKeys(response http.ResponseWriter, denylist *sync.Map) {
keys := []interface{}{}

i := 0
for k := range *denylist {
keys[i] = k
i++
}
denylist.Range(func(key interface{}, value interface{}) bool {
keys = append(keys, key)
return true
})

response.Header().Set("Content-Type", "application/json")
response.WriteHeader(http.StatusOK)
Expand All @@ -53,9 +53,9 @@ func listDenylistKeys(response http.ResponseWriter, denylist *map[string]bool) {
}

// GET /denylist/...
func getDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *map[string]bool) {
func getDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) {
id := request.URL.Path
_, exists := (*denylist)[id]
_, exists := denylist.Load(id)
if !exists {
http.Error(response, "denylist entry not found with that id", http.StatusNotFound)
return
Expand All @@ -71,27 +71,27 @@ func getDenylistEntry(response http.ResponseWriter, request *http.Request, denyl
}

// PUT /denylist/...
func createDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *map[string]bool) {
func createDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) {
id := request.URL.Path
_, exists := (*denylist)[id]
_, exists := denylist.Load(id)
if exists {
response.WriteHeader(http.StatusNoContent)
return
}

(*denylist)[id] = true
denylist.Store(id, true)
response.WriteHeader(http.StatusCreated)
}

// DELETE /denylist/...
func deleteDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *map[string]bool) {
func deleteDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) {
id := request.URL.Path
_, exists := (*denylist)[id]
_, exists := denylist.Load(id)
if !exists {
http.Error(response, "denylist entry not found with that id", http.StatusNotFound)
return
}

delete(*denylist, id)
denylist.Delete(id)
response.WriteHeader(http.StatusNoContent)
}
3 changes: 2 additions & 1 deletion lib/oplog/tail.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"errors"
"strings"
"sync"
"time"

"github.com/tulip/oplogtoredis/lib/config"
Expand All @@ -29,7 +30,7 @@ type Tailer struct {
RedisClients []redis.UniversalClient
RedisPrefix string
MaxCatchUp time.Duration
Denylist *map[string]bool
Denylist *sync.Map
}

// Raw oplog entry from Mongo
Expand Down
5 changes: 3 additions & 2 deletions lib/oplog/tail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oplog
import (
"errors"
"strconv"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -83,7 +84,7 @@ func TestGetStartTime(t *testing.T) {
RedisClients: redisClient,
RedisPrefix: "someprefix.",
MaxCatchUp: maxCatchUp,
Denylist: &map[string]bool{},
Denylist: &sync.Map{},
}

actualResult := tailer.getStartTime(func() (*primitive.Timestamp, error) {
Expand Down Expand Up @@ -286,7 +287,7 @@ func TestParseRawOplogEntry(t *testing.T) {

for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
got := (&Tailer{Denylist: &map[string]bool{}}).parseRawOplogEntry(test.in, nil)
got := (&Tailer{Denylist: &sync.Map{}}).parseRawOplogEntry(test.in, nil)

if diff := pretty.Compare(got, test.want); diff != "" {
t.Errorf("Got incorrect result (-got +want)\n%s", diff)
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func main() {

bufferSize := 10000
waitGroup := sync.WaitGroup{}
denylist := map[string]bool{}
denylist := sync.Map{}

for i := 0; i < config.WriteParallelism(); i++ {
redisClients, err := createRedisClients()
Expand Down Expand Up @@ -249,7 +249,7 @@ func createRedisClients() ([]redis.UniversalClient, error) {
return ret, nil
}

func makeHTTPServer(aggregatedClients [][]redis.UniversalClient, mongo *mongo.Client, denylistMap *map[string]bool) *http.Server {
func makeHTTPServer(aggregatedClients [][]redis.UniversalClient, mongo *mongo.Client, denylistMap *sync.Map) *http.Server {
mux := http.NewServeMux()

mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 9df196b

Please sign in to comment.