diff --git a/default.nix b/default.nix index e5e2f96a..ec09e7a7 100644 --- a/default.nix +++ b/default.nix @@ -9,7 +9,7 @@ buildGoModule { ''; # update: set value to an empty string and run `nix build`. This will download Go, fetch the dependencies and calculates their hash. - vendorHash = "sha256-ceToA2DC1bhmg9WIeNSAfoNoU7sk9PrQqgqt5UbpivQ="; + vendorHash = "sha256-Vh7O0iMPG6nAvcyv92h5TVZS2awnR0vz75apyzJeu4c="; nativeBuildInputs = [ installShellFiles ]; doCheck = false; diff --git a/go.mod b/go.mod index 762f455a..83dac2b6 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/alicebob/miniredis v2.5.0+incompatible github.com/deckarep/golang-set v1.7.1 github.com/go-redis/redis/v7 v7.4.1 + github.com/go-redis/redis/v8 v8.11.5 github.com/gorilla/websocket v1.4.2 github.com/juju/mgo/v2 v2.0.0-20210302023703-70d5d206e208 github.com/juju/replicaset v0.0.0-20210302050932-0303c8575745 @@ -29,10 +30,10 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/golang/protobuf v1.4.3 // indirect github.com/golang/snappy v0.0.1 // indirect github.com/gomodule/redigo v1.8.5 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/juju/clock v0.0.0-20190205081909-9c5c9712527c // indirect github.com/juju/errors v0.0.0-20200330140219-3fe23663418f // indirect github.com/juju/loggo v0.0.0-20200526014432-9ce3a2e09b5e // indirect diff --git a/go.sum b/go.sum index 731a13b0..a75f916c 100644 --- a/go.sum +++ b/go.sum @@ -140,6 +140,8 @@ github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/renameio 
v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= diff --git a/integration-tests/acceptance/denylist_http_test.go b/integration-tests/acceptance/denylist_http_test.go new file mode 100644 index 00000000..5ecc0534 --- /dev/null +++ b/integration-tests/acceptance/denylist_http_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "os" + "reflect" + "testing" +) + +func doRequest(method string, path string, t *testing.T, expectedCode int) interface{} { + req, err := http.NewRequest(method, os.Getenv("OTR_URL")+path, &bytes.Buffer{}) + if err != nil { + t.Fatalf("Error creating req: %s", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("Error sending request: %s", err) + } + + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Error receiving response body: %s", err) + } + + if resp.StatusCode != expectedCode { + t.Fatalf("Expected status code %d, but got %d.\nBody was: %s", expectedCode, resp.StatusCode, respBody) + } + + if expectedCode == 200 { + var data interface{} + err = json.Unmarshal(respBody, &data) + if err != nil { + t.Fatalf("Error parsing JSON response: %s", err) + } + + return data + } + return nil +} + +// Test the /denylist HTTP operations +func TestDenyList(t *testing.T) { + // GET empty list of rules + data := doRequest("GET", "/denylist", t, 200) + if !reflect.DeepEqual(data, []interface{}{}) { + t.Fatalf("Expected empty list from blank 
GET, but got %#v", data) + } + // PUT new rule + doRequest("PUT", "/denylist/abc", t, 201) + // GET list with new rule in it + data = doRequest("GET", "/denylist", t, 200) + if !reflect.DeepEqual(data, []interface{}{"abc"}) { + t.Fatalf("Expected singleton from GET, but got %#v", data) + } + // GET existing rule + data = doRequest("GET", "/denylist/abc", t, 200) + if !reflect.DeepEqual(data, "abc") { + t.Fatalf("Expected matched body from GET, but got %#v", data) + } + // PUT second rule + doRequest("PUT", "/denylist/def", t, 201) + // GET second rule + data = doRequest("GET", "/denylist/def", t, 200) + if !reflect.DeepEqual(data, "def") { + t.Fatalf("Expected matched body from GET, but got %#v", data) + } + // GET list with both rules + data = doRequest("GET", "/denylist", t, 200) + // check both permutations, in case the server reordered them + if !reflect.DeepEqual(data, []interface{}{"abc", "def"}) && !reflect.DeepEqual(data, []interface{}{"def", "abc"}) { + t.Fatalf("Expected doubleton from GET, but got %#v", data) + } + // DELETE first rule + doRequest("DELETE", "/denylist/abc", t, 204) + // GET first rule + doRequest("GET", "/denylist/abc", t, 404) + // GET list with only second rule + data = doRequest("GET", "/denylist", t, 200) + if !reflect.DeepEqual(data, []interface{}{"def"}) { + t.Fatalf("Expected singleton from GET, but got %#v", data) + } +} diff --git a/integration-tests/acceptance/denylist_oplog_test.go b/integration-tests/acceptance/denylist_oplog_test.go new file mode 100644 index 00000000..45acbafc --- /dev/null +++ b/integration-tests/acceptance/denylist_oplog_test.go @@ -0,0 +1,72 @@ +package main + +import ( + "context" + "testing" + + "github.com/tulip/oplogtoredis/integration-tests/helpers" + "go.mongodb.org/mongo-driver/bson" +) + +func TestDenyOplog(t *testing.T) { + harness := startHarness() + defer harness.stop() + + _, err := harness.mongoClient.Collection("Foo").InsertOne(context.Background(), bson.M{ + "_id": "id1", + "f": "1", + }) 
+ if err != nil { + panic(err) + } + + expectedMessage1 := helpers.OTRMessage{ + Event: "i", + Document: map[string]interface{}{ + "_id": "id1", + }, + Fields: []string{"_id", "f"}, + } + + harness.verify(t, map[string][]helpers.OTRMessage{ + "tests.Foo": {expectedMessage1}, + "tests.Foo::id1": {expectedMessage1}, + }) + + doRequest("PUT", "/denylist/tests", t, 201) + + _, err = harness.mongoClient.Collection("Foo").InsertOne(context.Background(), bson.M{ + "_id": "id2", + "g": "2", + }) + if err != nil { + panic(err) + } + + // second message should not have been received, since it got denied + harness.verify(t, map[string][]helpers.OTRMessage{}) + + doRequest("DELETE", "/denylist/tests", t, 204) + + _, err = harness.mongoClient.Collection("Foo").InsertOne(context.Background(), bson.M{ + "_id": "id3", + "h": "3", + }) + if err != nil { + panic(err) + } + + expectedMessage3 := helpers.OTRMessage{ + Event: "i", + Document: map[string]interface{}{ + "_id": "id3", + }, + Fields: []string{"_id", "h"}, + } + + // back to normal now that the deny rule is gone + harness.verify(t, map[string][]helpers.OTRMessage{ + "tests.Foo": {expectedMessage3}, + "tests.Foo::id3": {expectedMessage3}, + }) +} diff --git a/lib/config/main.go b/lib/config/main.go index 0dec671d..fc94c324 100644 --- a/lib/config/main.go +++ b/lib/config/main.go @@ -4,8 +4,9 @@ package config import ( - "time" "strings" + "time" + "github.com/kelseyhightower/envconfig" ) @@ -21,6 +22,7 @@ type oplogtoredisConfiguration struct { MongoConnectTimeout time.Duration `default:"10s" split_words:"true"` MongoQueryTimeout time.Duration `default:"5s" split_words:"true"` OplogV2ExtractSubfieldChanges bool `default:"false" envconfig:"OPLOG_V2_EXTRACT_SUBFIELD_CHANGES"` + WriteParallelism int `default:"1" split_words:"true"` } var globalConfig *oplogtoredisConfiguration @@ -131,6 +133,13 @@ func OplogV2ExtractSubfieldChanges() bool { return globalConfig.OplogV2ExtractSubfieldChanges } +// WriteParallelism controls how 
many parallel write loops will be run (sharded based on a hash + // of the database name.) Each parallel loop has its own redis connection and internal buffer. + // Healthz endpoint will report fail if any one of them dies. +func WriteParallelism() int { + return globalConfig.WriteParallelism +} + // ParseEnv parses the current environment variables and updates the stored // configuration. It is *not* threadsafe, and should just be called once // at the start of the program. diff --git a/lib/denylist/http.go b/lib/denylist/http.go new file mode 100644 index 00000000..9086e730 --- /dev/null +++ b/lib/denylist/http.go @@ -0,0 +1,97 @@ +package denylist + +import ( + "encoding/json" + "net/http" + "sync" +) + +// CollectionEndpoint serves the endpoints for the whole Denylist at /denylist +func CollectionEndpoint(denylist *sync.Map) func(http.ResponseWriter, *http.Request) { + return func(response http.ResponseWriter, request *http.Request) { + switch request.Method { + case "GET": + listDenylistKeys(response, denylist) + default: + http.Error(response, http.StatusText(http.StatusNotFound), http.StatusNotFound) + } + } + } +} + +// SingleEndpoint serves the endpoints for particular Denylist entries at /denylist/... 
+func SingleEndpoint(denylist *sync.Map) func(http.ResponseWriter, *http.Request) { + return func(response http.ResponseWriter, request *http.Request) { + switch request.Method { + case "GET": + getDenylistEntry(response, request, denylist) + case "PUT": + createDenylistEntry(response, request, denylist) + case "DELETE": + deleteDenylistEntry(response, request, denylist) + default: + http.Error(response, http.StatusText(http.StatusNotFound), http.StatusNotFound) + } + } +} + +// GET /denylist +func listDenylistKeys(response http.ResponseWriter, denylist *sync.Map) { + keys := []interface{}{} + + denylist.Range(func(key interface{}, value interface{}) bool { + keys = append(keys, key) + return true + }) + + response.Header().Set("Content-Type", "application/json") + response.WriteHeader(http.StatusOK) + err := json.NewEncoder(response).Encode(keys) + if err != nil { + http.Error(response, "couldn't encode result", http.StatusInternalServerError) + return + } +} + +// GET /denylist/... +func getDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) { + id := request.URL.Path + _, exists := denylist.Load(id) + if !exists { + http.Error(response, "denylist entry not found with that id", http.StatusNotFound) + return + } + + response.Header().Set("Content-Type", "application/json") + response.WriteHeader(http.StatusOK) + err := json.NewEncoder(response).Encode(id) + if err != nil { + http.Error(response, "couldn't encode result", http.StatusInternalServerError) + return + } +} + +// PUT /denylist/... +func createDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) { + id := request.URL.Path + _, exists := denylist.Load(id) + if exists { + response.WriteHeader(http.StatusNoContent) + return + } + + denylist.Store(id, true) + response.WriteHeader(http.StatusCreated) +} + +// DELETE /denylist/... 
+func deleteDenylistEntry(response http.ResponseWriter, request *http.Request, denylist *sync.Map) { + id := request.URL.Path + _, exists := denylist.Load(id) + if !exists { + http.Error(response, "denylist entry not found with that id", http.StatusNotFound) + return + } + + denylist.Delete(id) + response.WriteHeader(http.StatusNoContent) +} diff --git a/lib/oplog/processor.go b/lib/oplog/processor.go index 6484f871..b9ebda7b 100644 --- a/lib/oplog/processor.go +++ b/lib/oplog/processor.go @@ -1,6 +1,9 @@ package oplog import ( + "bytes" + "crypto/sha256" + "encoding/binary" "encoding/json" "strings" @@ -78,6 +81,16 @@ func processOplogEntry(op *oplogEntry) (*redispub.Publication, error) { return nil, errors.Wrap(err, "marshalling outgoing message") } + hash := sha256.Sum256([]byte(op.Database)) + intSlice := hash[len(hash)-8:] + + var hashInt uint64 + + err = binary.Read(bytes.NewReader(intSlice), binary.LittleEndian, &hashInt) + if err != nil { + panic(errors.Wrap(err, "decoding database hash as uint64")) + } + // We need to publish on both the full-collection channel and the // single-document channel return &redispub.Publication{ @@ -92,7 +105,8 @@ func processOplogEntry(op *oplogEntry) (*redispub.Publication, error) { Msg: msgJSON, OplogTimestamp: op.Timestamp, - TxIdx: op.TxIdx, + TxIdx: op.TxIdx, + ParallelismKey: int(hashInt), }, nil } diff --git a/lib/oplog/processor_test.go b/lib/oplog/processor_test.go index 09d76efd..95079b79 100644 --- a/lib/oplog/processor_test.go +++ b/lib/oplog/processor_test.go @@ -14,6 +14,9 @@ import ( "go.mongodb.org/mongo-driver/bson" ) +// hash of the database name "foo" to be expected for ParallelismKey +const fooHash = -5843589418109203719 + // nolint: gocyclo func TestProcessOplogEntry(t *testing.T) { // We can't compare raw publications because they contain JSON that can @@ -25,9 +28,10 @@ func TestProcessOplogEntry(t *testing.T) { Fields []string `json:"f"` } type decodedPublication struct { - Channels []string - Msg 
decodedPublicationMessage - OplogTimestamp primitive.Timestamp + Channels []string + Msg decodedPublicationMessage + OplogTimestamp primitive.Timestamp + ParallelismKey int } testObjectId, err := primitive.ObjectIDFromHex("deadbeefdeadbeefdeadbeef") @@ -67,6 +71,7 @@ func TestProcessOplogEntry(t *testing.T) { Fields: []string{"some"}, }, OplogTimestamp: primitive.Timestamp{T: 1234}, + ParallelismKey: fooHash, }, }, "Replacement update": { @@ -92,6 +97,7 @@ func TestProcessOplogEntry(t *testing.T) { Fields: []string{"some", "new"}, }, OplogTimestamp: primitive.Timestamp{T: 1234}, + ParallelismKey: fooHash, }, }, "Non-replacement update": { @@ -123,6 +129,7 @@ func TestProcessOplogEntry(t *testing.T) { Fields: []string{"a", "b", "c"}, }, OplogTimestamp: primitive.Timestamp{T: 1234}, + ParallelismKey: fooHash, }, }, "Delete": { @@ -145,6 +152,7 @@ func TestProcessOplogEntry(t *testing.T) { Fields: []string{}, }, OplogTimestamp: primitive.Timestamp{T: 1234}, + ParallelismKey: fooHash, }, }, "ObjectID id": { @@ -172,6 +180,7 @@ func TestProcessOplogEntry(t *testing.T) { Fields: []string{"some"}, }, OplogTimestamp: primitive.Timestamp{T: 1234}, + ParallelismKey: fooHash, }, }, "Unsupported id type": { @@ -242,9 +251,10 @@ func TestProcessOplogEntry(t *testing.T) { sort.Strings(msg.Fields) return &decodedPublication{ - Channels: pub.Channels, - Msg: msg, - OplogTimestamp: pub.OplogTimestamp, + Channels: pub.Channels, + Msg: msg, + OplogTimestamp: pub.OplogTimestamp, + ParallelismKey: pub.ParallelismKey, } } diff --git a/lib/oplog/tail.go b/lib/oplog/tail.go index 0f62b032..b14773ca 100644 --- a/lib/oplog/tail.go +++ b/lib/oplog/tail.go @@ -7,6 +7,7 @@ import ( "context" "errors" "strings" + "sync" "time" "github.com/tulip/oplogtoredis/lib/config" @@ -25,10 +26,11 @@ import ( // Tailer persistently tails the oplog of a Mongo cluster, handling // reconnection and resumption of where it left off. 
type Tailer struct { - MongoClient *mongo.Client + MongoClient *mongo.Client RedisClients []redis.UniversalClient - RedisPrefix string - MaxCatchUp time.Duration + RedisPrefix string + MaxCatchUp time.Duration + Denylist *sync.Map } // Raw oplog entry from Mongo @@ -93,10 +95,10 @@ var ( }, }, []string{"database", "status"}) - metricLastOplogEntryStaleness = promauto.NewGauge(prometheus.GaugeOpts{ + metricLastReceivedStaleness = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: "otr", Subsystem: "oplog", - Name: "last_entry_staleness_seconds", + Name: "last_received_staleness", Help: "Gauge recording the difference between this server's clock and the timestamp on the last read oplog entry.", }) ) @@ -107,7 +109,7 @@ func init() { // Tail begins tailing the oplog. It doesn't return unless it receives a message // on the stop channel, in which case it wraps up its work and then returns. -func (tailer *Tailer) Tail(out chan<- *redispub.Publication, stop <-chan bool) { +func (tailer *Tailer) Tail(out []chan<- *redispub.Publication, stop <-chan bool) { childStopC := make(chan bool) wasStopped := false @@ -131,7 +133,9 @@ func (tailer *Tailer) Tail(out chan<- *redispub.Publication, stop <-chan bool) { } } -func (tailer *Tailer) tailOnce(out chan<- *redispub.Publication, stop <-chan bool) { +func (tailer *Tailer) tailOnce(out []chan<- *redispub.Publication, stop <-chan bool) { + parallelismSize := len(out) + session, err := tailer.MongoClient.StartSession() if err != nil { log.Log.Errorw("Failed to start Mongo session", "error", err) @@ -140,7 +144,7 @@ func (tailer *Tailer) tailOnce(out chan<- *redispub.Publication, stop <-chan boo oplogCollection := session.Client().Database("local").Collection("oplog.rs") - startTime := tailer.getStartTime(func() (*primitive.Timestamp, error) { + startTime := tailer.getStartTime(parallelismSize-1, func() (*primitive.Timestamp, error) { // Get the timestamp of the last entry in the oplog (as a position to // start from if we don't have 
a last-written timestamp from Redis) var entry rawOplogEntry @@ -195,7 +199,7 @@ func (tailer *Tailer) tailOnce(out chan<- *redispub.Publication, stop <-chan boo continue } - ts, pubs := tailer.unmarshalEntry(rawData) + ts, pubs := tailer.unmarshalEntry(rawData, tailer.Denylist) if ts != nil { lastTimestamp = *ts @@ -203,7 +207,8 @@ func (tailer *Tailer) tailOnce(out chan<- *redispub.Publication, stop <-chan boo for _, pub := range pubs { if pub != nil { - out <- pub + outIdx := (pub.ParallelismKey%parallelismSize + parallelismSize) % parallelismSize + out[outIdx] <- pub } else { log.Log.Error("Nil Redis publication") } @@ -328,7 +333,7 @@ func closeCursor(cursor *mongo.Cursor) { // // The timestamp of the entry is returned so that tailOnce knows the timestamp of the last entry it read, even if it // ignored it or failed at some later step. -func (tailer *Tailer) unmarshalEntry(rawData bson.Raw) (timestamp *primitive.Timestamp, pubs []*redispub.Publication) { +func (tailer *Tailer) unmarshalEntry(rawData bson.Raw, denylist *sync.Map) (timestamp *primitive.Timestamp, pubs []*redispub.Publication) { var result rawOplogEntry err := bson.Unmarshal(rawData, &result) @@ -345,6 +350,7 @@ func (tailer *Tailer) unmarshalEntry(rawData bson.Raw) (timestamp *primitive.Tim status := "ignored" database := "(no database)" messageLen := float64(len(rawData)) + metricLastReceivedStaleness.Set(float64(time.Since(time.Unix(int64(timestamp.T), 0)))) defer func() { // TODO: remove these in a future version @@ -353,13 +359,17 @@ func (tailer *Tailer) unmarshalEntry(rawData bson.Raw) (timestamp *primitive.Tim metricOplogEntriesBySize.WithLabelValues(database, status).Observe(messageLen) metricMaxOplogEntryByMinute.Report(messageLen, database, status) - metricLastOplogEntryStaleness.Set(float64(time.Since(time.Unix(int64(timestamp.T), 0)))) }() if len(entries) > 0 { database = entries[0].Database } + if _, denied := denylist.Load(database); denied { + log.Log.Debugw("Skipping oplog 
entry", "database", database) + return + } + type errEntry struct { err error op *oplogEntry @@ -403,8 +413,9 @@ func (tailer *Tailer) unmarshalEntry(rawData bson.Raw) (timestamp *primitive.Tim // We take the function to get the timestamp of the last oplog entry (as a // fallback if we don't have a latest timestamp from Redis) as an arg instead // of using tailer.mongoClient directly so we can unit test this function -func (tailer *Tailer) getStartTime(getTimestampOfLastOplogEntry func() (*primitive.Timestamp, error)) primitive.Timestamp { - ts, tsTime, redisErr := redispub.LastProcessedTimestamp(tailer.RedisClients[0], tailer.RedisPrefix) +func (tailer *Tailer) getStartTime(maxOrdinal int, getTimestampOfLastOplogEntry func() (*primitive.Timestamp, error)) primitive.Timestamp { + // Get the earliest "last processed time" for each shard. This assumes that the number of shards is constant. + ts, tsTime, redisErr := redispub.FirstLastProcessedTimestamp(tailer.RedisClients[0], tailer.RedisPrefix, maxOrdinal) if redisErr == nil { // we have a last write time, check that it's not too far in the diff --git a/lib/oplog/tail_test.go b/lib/oplog/tail_test.go index 8abcca44..651e346d 100644 --- a/lib/oplog/tail_test.go +++ b/lib/oplog/tail_test.go @@ -3,6 +3,7 @@ package oplog import ( "errors" "strconv" + "sync" "testing" "time" @@ -73,7 +74,7 @@ func TestGetStartTime(t *testing.T) { panic(err) } defer redisServer.Close() - require.NoError(t, redisServer.Set("someprefix.lastProcessedEntry", strconv.FormatInt(int64(test.redisTimestamp.T), 10))) + require.NoError(t, redisServer.Set("someprefix.lastProcessedEntry.0", strconv.FormatInt(int64(test.redisTimestamp.T), 10))) redisClient := []redis.UniversalClient{redis.NewUniversalClient(&redis.UniversalOptions{ Addrs: []string{redisServer.Addr()}, @@ -81,11 +82,12 @@ func TestGetStartTime(t *testing.T) { tailer := Tailer{ RedisClients: redisClient, - RedisPrefix: "someprefix.", - MaxCatchUp: maxCatchUp, + RedisPrefix: 
"someprefix.", + MaxCatchUp: maxCatchUp, + Denylist: &sync.Map{}, } - actualResult := tailer.getStartTime(func() (*primitive.Timestamp, error) { + actualResult := tailer.getStartTime(0, func() (*primitive.Timestamp, error) { if test.mongoEndOfOplogErr != nil { return nil, test.mongoEndOfOplogErr } @@ -285,7 +287,7 @@ func TestParseRawOplogEntry(t *testing.T) { for testName, test := range tests { t.Run(testName, func(t *testing.T) { - got := (&Tailer{}).parseRawOplogEntry(test.in, nil) + got := (&Tailer{Denylist: &sync.Map{}}).parseRawOplogEntry(test.in, nil) if diff := pretty.Compare(got, test.want); diff != "" { t.Errorf("Got incorrect result (-got +want)\n%s", diff) diff --git a/lib/redispub/lastProcessedTime.go b/lib/redispub/lastProcessedTime.go index dbfa44ba..f320b641 100644 --- a/lib/redispub/lastProcessedTime.go +++ b/lib/redispub/lastProcessedTime.go @@ -2,6 +2,7 @@ package redispub import ( "context" + "strconv" "time" "github.com/go-redis/redis/v8" @@ -17,8 +18,8 @@ import ( // // If oplogtoredis has not processed any messages, returns redis.Nil as an // error. -func LastProcessedTimestamp(redisClient redis.UniversalClient, metadataPrefix string) (primitive.Timestamp, time.Time, error) { - str, err := redisClient.Get(context.Background(), metadataPrefix+"lastProcessedEntry").Result() +func LastProcessedTimestamp(redisClient redis.UniversalClient, metadataPrefix string, ordinal int) (primitive.Timestamp, time.Time, error) { + str, err := redisClient.Get(context.Background(), metadataPrefix+"lastProcessedEntry."+strconv.Itoa(ordinal)).Result() if err != nil { return primitive.Timestamp{}, time.Unix(0, 0), err } @@ -31,3 +32,22 @@ func LastProcessedTimestamp(redisClient redis.UniversalClient, metadataPrefix st time := mongoTimestampToTime(ts) return ts, time, nil } + +// FirstLastProcessedTimestamp runs LastProcessedTimestamp for each ordinal up to the provided count, +// then returns the earliest such timestamp obtained. 
If any ordinal produces an error, that error is returned. +func FirstLastProcessedTimestamp(redisClient redis.UniversalClient, metadataPrefix string, maxOrdinal int) (primitive.Timestamp, time.Time, error) { + var minTs primitive.Timestamp + var minTime time.Time + for i := 0; i <= maxOrdinal; i++ { + ts, time, err := LastProcessedTimestamp(redisClient, metadataPrefix, i) + if err != nil { + return ts, time, err + } + + if i == 0 || ts.Before(minTs) { + minTs = ts + minTime = time + } + } + return minTs, minTime, nil +} diff --git a/lib/redispub/lastProcessedTime_test.go b/lib/redispub/lastProcessedTime_test.go index ffb4b185..acd689dc 100644 --- a/lib/redispub/lastProcessedTime_test.go +++ b/lib/redispub/lastProcessedTime_test.go @@ -31,9 +31,9 @@ func TestLastProcessedTimestampSuccess(t *testing.T) { redisServer, redisClient := startMiniredis() defer redisServer.Close() - require.NoError(t, redisServer.Set("someprefix.lastProcessedEntry", encodeMongoTimestamp(nowTS))) + require.NoError(t, redisServer.Set("someprefix.lastProcessedEntry.0", encodeMongoTimestamp(nowTS))) - gotTS, gotTime, err := LastProcessedTimestamp(redisClient, "someprefix.") + gotTS, gotTime, err := LastProcessedTimestamp(redisClient, "someprefix.", 0) if err != nil { t.Errorf("Got unexpected error: %s", err) @@ -52,7 +52,7 @@ func TestLastProcessedTimestampNoRecord(t *testing.T) { redisServer, redisClient := startMiniredis() defer redisServer.Close() - _, _, err := LastProcessedTimestamp(redisClient, "someprefix.") + _, _, err := LastProcessedTimestamp(redisClient, "someprefix.", 0) if err == nil { t.Errorf("Expected redis.Nil error, got no error") @@ -64,9 +64,9 @@ func TestLastProcessedTimestampNoRecord(t *testing.T) { func TestLastProcessedTimestampInvalidRecord(t *testing.T) { redisServer, redisClient := startMiniredis() defer redisServer.Close() - require.NoError(t, redisServer.Set("someprefix.lastProcessedEntry", "not a number")) + require.NoError(t, 
redisServer.Set("someprefix.lastProcessedEntry.0", "not a number")) - _, _, err := LastProcessedTimestamp(redisClient, "someprefix.") + _, _, err := LastProcessedTimestamp(redisClient, "someprefix.", 0) if err == nil { t.Errorf("Expected strconv error, got no error") @@ -80,7 +80,7 @@ func TestLastProcessedTimestampRedisError(t *testing.T) { Addrs: []string{"not a server"}, }) - _, _, err := LastProcessedTimestamp(redisClient, "someprefix.") + _, _, err := LastProcessedTimestamp(redisClient, "someprefix.", 0) if err == nil { t.Errorf("Expected TCP error, got no error") diff --git a/lib/redispub/publication.go b/lib/redispub/publication.go index 156196a2..babc02c6 100644 --- a/lib/redispub/publication.go +++ b/lib/redispub/publication.go @@ -20,4 +20,8 @@ type Publication struct { // TxIdx is the index of the operation within a transaction. Used to supplement OplogTimestamp in a transaction. TxIdx uint + + // ParallelismKey is a number representing which parallel write loop will process this message. + // It is a hash of the database name, assuming that a single database is the unit of ordering guarantee. 
+ ParallelismKey int } diff --git a/lib/redispub/publisher.go b/lib/redispub/publisher.go index f6363eac..6236c588 100644 --- a/lib/redispub/publisher.go +++ b/lib/redispub/publisher.go @@ -7,6 +7,7 @@ package redispub import ( "context" "fmt" + "strconv" "strings" "time" @@ -61,28 +62,43 @@ var metricLastCommandDuration = promauto.NewGauge(prometheus.GaugeOpts{ Help: "The round trip time in seconds of the most recent write to Redis.", }) +var metricStalenessPreRetries = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "otr", + Subsystem: "redispub", + Name: "pre_retry_staleness", + Help: "Gauge recording the staleness on receiving a message from the tailing routine.", +}, []string{"ordinal"}) + +var metricLastOplogEntryStaleness = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "otr", + Subsystem: "redispub", + Name: "last_entry_staleness_seconds", + Help: "Gauge recording the difference between this server's clock and the timestamp on the last published oplog entry.", +}, []string{"ordinal"}) + // PublishStream reads Publications from the given channel and publishes them // to Redis. 
-func PublishStream(clients []redis.UniversalClient, in <-chan *Publication, opts *PublishOpts, stop <-chan bool) { +func PublishStream(clients []redis.UniversalClient, in <-chan *Publication, opts *PublishOpts, stop <-chan bool, ordinal int) { + // Start up a background goroutine for periodically updating the last-processed // timestamp timestampC := make(chan primitive.Timestamp) - for _,client := range clients { - go periodicallyUpdateTimestamp(client, timestampC, opts) + for _, client := range clients { + go periodicallyUpdateTimestamp(client, timestampC, opts, ordinal) } // Redis expiration is in integer seconds, so we have to convert the // time.Duration dedupeExpirationSeconds := int(opts.DedupeExpiration.Seconds()) - type PubFn func(*Publication)error + type PubFn func(*Publication) error var publishFns []PubFn - for _,client := range clients { + for _, client := range clients { client := client publishFn := func(p *Publication) error { - return publishSingleMessage(p, client, opts.MetadataPrefix, dedupeExpirationSeconds) + return publishSingleMessage(p, client, opts.MetadataPrefix, dedupeExpirationSeconds, ordinal) } publishFns = append(publishFns, publishFn) } @@ -97,11 +113,11 @@ func PublishStream(clients []redis.UniversalClient, in <-chan *Publication, opts return case p := <-in: - for i,publishFn := range publishFns { + metricStalenessPreRetries.WithLabelValues(strconv.Itoa(ordinal)).Set(float64(time.Since(time.Unix(int64(p.OplogTimestamp.T), 0)).Seconds())) + for i, publishFn := range publishFns { err := publishSingleMessageWithRetries(p, 30, time.Second, publishFn) log.Log.Debugw("Published to", "idx", i) - if err != nil { metricSendFailed.Inc() log.Log.Errorw("Permanent error while trying to publish message; giving up", @@ -146,8 +162,9 @@ func publishSingleMessageWithRetries(p *Publication, maxRetries int, sleepTime t return errors.Errorf("sending message (retried %v times)", maxRetries) } -func publishSingleMessage(p *Publication, client 
redis.UniversalClient, prefix string, dedupeExpirationSeconds int) error { +func publishSingleMessage(p *Publication, client redis.UniversalClient, prefix string, dedupeExpirationSeconds int, ordinal int) error { start := time.Now() + metricLastOplogEntryStaleness.WithLabelValues(strconv.Itoa(ordinal)).Set(float64(time.Since(time.Unix(int64(p.OplogTimestamp.T), 0)).Seconds())) _, err := publishDedupe.Run( context.Background(), @@ -181,14 +198,14 @@ func formatKey(p *Publication, prefix string) string { // channel, and this function throttles that to only update occasionally. // // This blocks forever; it should be run in a goroutine -func periodicallyUpdateTimestamp(client redis.UniversalClient, timestamps <-chan primitive.Timestamp, opts *PublishOpts) { +func periodicallyUpdateTimestamp(client redis.UniversalClient, timestamps <-chan primitive.Timestamp, opts *PublishOpts, ordinal int) { var lastFlush time.Time var mostRecentTimestamp primitive.Timestamp var needFlush bool flush := func() { if needFlush { - client.Set(context.Background(), opts.MetadataPrefix+"lastProcessedEntry", encodeMongoTimestamp(mostRecentTimestamp), 0) + client.Set(context.Background(), opts.MetadataPrefix+"lastProcessedEntry."+strconv.Itoa(ordinal), encodeMongoTimestamp(mostRecentTimestamp), 0) lastFlush = time.Now() needFlush = false } diff --git a/lib/redispub/publisher_test.go b/lib/redispub/publisher_test.go index 0d7f39fa..66553f22 100644 --- a/lib/redispub/publisher_test.go +++ b/lib/redispub/publisher_test.go @@ -121,11 +121,11 @@ func TestPeriodicallyUpdateTimestamp(t *testing.T) { periodicallyUpdateTimestamp(redisClient, timestampC, &PublishOpts{ MetadataPrefix: "someprefix.", FlushInterval: testSpeed, - }) + }, 0) waitGroup.Done() }() - key := "someprefix.lastProcessedEntry" + key := "someprefix.lastProcessedEntry.0" // Key should be unset if redisServer.Exists(key) { diff --git a/main.go b/main.go index 7eb68144..6bce54a6 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import 
( "context" "crypto/tls" "encoding/json" + "fmt" stdlog "log" "net/http" "net/http/pprof" @@ -17,6 +18,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" "github.com/tulip/oplogtoredis/lib/config" + "github.com/tulip/oplogtoredis/lib/denylist" "github.com/tulip/oplogtoredis/lib/log" "github.com/tulip/oplogtoredis/lib/oplog" "github.com/tulip/oplogtoredis/lib/parse" @@ -53,74 +55,93 @@ func main() { }() log.Log.Info("Initialized connection to Mongo") - redisClients, err := createRedisClients() - if err != nil { - panic("Error initializing Redis client: " + err.Error()) - } - defer func() { - for _, redisClient := range redisClients { - redisCloseErr := redisClient.Close() - if redisCloseErr != nil { - log.Log.Errorw("Error closing Redis client", - "error", redisCloseErr) - } - } - }() - log.Log.Info("Initialized connection to Redis") + parallelism := config.WriteParallelism() + aggregatedRedisClients := make([][]redis.UniversalClient, parallelism) + aggregatedRedisPubs := make([]chan<- *redispub.Publication, parallelism) + stopRedisPubs := make([]chan bool, parallelism) - // We crate two goroutines: - // - // The oplog.Tail goroutine reads messages from the oplog, and generates the - // messages that we need to write to redis. It then writes them to a - // buffered channel. - // - // The redispub.PublishStream goroutine reads messages from the buffered channel - // and sends them to Redis. 
- // - // TODO PERF: Use a leaky buffer (https://github.com/tulip/oplogtoredis/issues/2) bufferSize := 10000 - redisPubs := make(chan *redispub.Publication, bufferSize) + waitGroup := sync.WaitGroup{} + denylist := sync.Map{} + + for i := 0; i < config.WriteParallelism(); i++ { + redisClients, err := createRedisClients() + if err != nil { + panic(fmt.Sprintf("[%d] Error initializing Redis client: %s", i, err.Error())) + } + defer func() { + for _, redisClient := range redisClients { + redisCloseErr := redisClient.Close() + if redisCloseErr != nil { + log.Log.Errorw("Error closing Redis client", + "error", redisCloseErr, + "i", i) + } + } + }() + log.Log.Infow("Initialized connection to Redis", "i", i) + + aggregatedRedisClients[i] = redisClients + + // We create two goroutines: + // + // The oplog.Tail goroutine reads messages from the oplog, and generates the + // messages that we need to write to redis. It then writes them to a + // buffered channel. + // + // The redispub.PublishStream goroutine reads messages from the buffered channel + // and sends them to Redis. 
+ // + // TODO PERF: Use a leaky buffer (https://github.com/tulip/oplogtoredis/issues/2) + redisPubs := make(chan *redispub.Publication, bufferSize) + aggregatedRedisPubs[i] = redisPubs + + stopRedisPub := make(chan bool) + waitGroup.Add(1) + go func(ordinal int) { + redispub.PublishStream(redisClients, redisPubs, &redispub.PublishOpts{ + FlushInterval: config.TimestampFlushInterval(), + DedupeExpiration: config.RedisDedupeExpiration(), + MetadataPrefix: config.RedisMetadataPrefix(), + }, stopRedisPub, ordinal) + log.Log.Infow("Redis publisher completed", "i", i) + waitGroup.Done() + }(i) + log.Log.Info("Started up processing goroutines") + stopRedisPubs[i] = stopRedisPub + } promauto.NewGaugeFunc(prometheus.GaugeOpts{ Namespace: "otr", Name: "buffer_available", Help: "Gauge indicating the available space in the buffer of oplog entries waiting to be written to redis.", - }, func () float64 { - return float64(bufferSize - len(redisPubs)) + }, func() float64 { + total := 0 + for _, redisPubs := range aggregatedRedisPubs { + total += (bufferSize - len(redisPubs)) + } + return float64(total) }) - waitGroup := sync.WaitGroup{} - stopOplogTail := make(chan bool) waitGroup.Add(1) go func() { tailer := oplog.Tailer{ MongoClient: mongoSession, - RedisClients: redisClients, - RedisPrefix: config.RedisMetadataPrefix(), - MaxCatchUp: config.MaxCatchUp(), + RedisClients: aggregatedRedisClients[0], // the tailer coroutine needs a redis client for determining start timestamp + // it doesn't really matter which one since this isn't a meaningful amount of load, so just take the first one + RedisPrefix: config.RedisMetadataPrefix(), + MaxCatchUp: config.MaxCatchUp(), + Denylist: &denylist, } - tailer.Tail(redisPubs, stopOplogTail) + tailer.Tail(aggregatedRedisPubs, stopOplogTail) log.Log.Info("Oplog tailer completed") waitGroup.Done() }() - stopRedisPub := make(chan bool) - waitGroup.Add(1) - go func() { - redispub.PublishStream(redisClients, redisPubs, &redispub.PublishOpts{ - 
FlushInterval: config.TimestampFlushInterval(), - DedupeExpiration: config.RedisDedupeExpiration(), - MetadataPrefix: config.RedisMetadataPrefix(), - }, stopRedisPub) - log.Log.Info("Redis publisher completed") - waitGroup.Done() - }() - log.Log.Info("Started up processing goroutines") - // Start one more goroutine for the HTTP server - httpServer := makeHTTPServer(redisClients, mongoSession) + httpServer := makeHTTPServer(aggregatedRedisClients, mongoSession, &denylist) go func() { httpErr := httpServer.ListenAndServe() if httpErr != nil { @@ -147,7 +168,9 @@ func main() { signal.Reset() stopOplogTail <- true - stopRedisPub <- true + for _, stopRedisPub := range stopRedisPubs { + stopRedisPub <- true + } err = httpServer.Shutdown(context.Background()) if err != nil { @@ -226,17 +249,19 @@ func createRedisClients() ([]redis.UniversalClient, error) { return ret, nil } -func makeHTTPServer(clients []redis.UniversalClient, mongo *mongo.Client) *http.Server { +func makeHTTPServer(aggregatedClients [][]redis.UniversalClient, mongo *mongo.Client, denylistMap *sync.Map) *http.Server { mux := http.NewServeMux() mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { redisOK := true - for _, redis := range clients { - redisErr := redis.Ping(context.Background()).Err() - redisOK = (redisOK && (redisErr == nil)) - if !redisOK { - log.Log.Errorw("Error connecting to Redis during healthz check", - "error", redisErr) + for _, clients := range aggregatedClients { + for _, redis := range clients { + redisErr := redis.Ping(context.Background()).Err() + redisOK = (redisOK && (redisErr == nil)) + if !redisOK { + log.Log.Errorw("Error connecting to Redis during healthz check", + "error", redisErr) + } } } @@ -276,5 +301,8 @@ func makeHTTPServer(clients []redis.UniversalClient, mongo *mongo.Client) *http. 
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + mux.HandleFunc("/denylist", denylist.CollectionEndpoint(denylistMap)) + mux.Handle("/denylist/", http.StripPrefix("/denylist/", http.HandlerFunc(denylist.SingleEndpoint(denylistMap)))) + return &http.Server{Addr: config.HTTPServerAddr(), Handler: mux} }