From 62607323474f59b5971fa1566c9e55f7f6cae5d9 Mon Sep 17 00:00:00 2001
From: Uday Patil
Date: Wed, 22 Nov 2023 13:03:37 -0500
Subject: [PATCH] [occ] OCC scheduler and validation fixes (#359)

## Describe your changes and provide context

This optimizes the OCC scheduler and validation:
- cache reads in `store/cache` and `store/cachekv` now take a read lock on the hit path (`sync.Mutex` -> `sync.RWMutex`) instead of an exclusive lock
- the multiversion store releases its read lock before running iterator/readset validation instead of holding it for the whole check
- validation runs concurrently across tasks, chained in transaction-index order by a per-task `ValidateCh`
- `deliverTxTask.Reset` is split out of `Increment`, so only tasks that fail validation get a new incarnation

## Testing performed to validate your change

Updated the "Test every tx accesses same key" case in `TestProcessAll`: increased `runs` from 25 to 50 and made the last-write assertion derive from the response count instead of a hard-coded `"49"`.

---------

Co-authored-by: Steven Landers
---
 store/cache/cache.go        |  33 +++---
 store/cachekv/store.go      |  33 ++++--
 store/multiversion/store.go |  13 ++-
 tasks/scheduler.go          | 215 ++++++++++++++++++++++++------------
 tasks/scheduler_test.go     |   4 +-
 5 files changed, 196 insertions(+), 102 deletions(-)

diff --git a/store/cache/cache.go b/store/cache/cache.go
index cbaeaeb86..b28675ba3 100644
--- a/store/cache/cache.go
+++ b/store/cache/cache.go
@@ -33,7 +33,7 @@ type (
 		// the same CommitKVStoreCache may be accessed concurrently by multiple
 		// goroutines due to transaction parallelization
-		mtx sync.Mutex
+		mtx sync.RWMutex
 	}
 
 	// CommitKVStoreCacheManager maintains a mapping from a StoreKey to a
@@ -102,27 +102,34 @@ func (ckv *CommitKVStoreCache) CacheWrap(storeKey types.StoreKey) types.CacheWra
 	return cachekv.NewStore(ckv, storeKey, ckv.cacheKVSize)
 }
 
+// getFromCache queries the write-through cache for a value by key.
+func (ckv *CommitKVStoreCache) getFromCache(key []byte) ([]byte, bool) {
+	ckv.mtx.RLock()
+	defer ckv.mtx.RUnlock()
+	return ckv.cache.Get(string(key))
+}
+
+// getAndWriteToCache queries the underlying CommitKVStore and writes the result to the cache.
+func (ckv *CommitKVStoreCache) getAndWriteToCache(key []byte) []byte {
+	ckv.mtx.Lock()
+	defer ckv.mtx.Unlock()
+	value := ckv.CommitKVStore.Get(key)
+	ckv.cache.Add(string(key), value)
+	return value
+}
+
 // Get retrieves a value by key. It will first look in the write-through cache.
 // If the value doesn't exist in the write-through cache, the query is delegated
 // to the underlying CommitKVStore.
 func (ckv *CommitKVStoreCache) Get(key []byte) []byte {
-	ckv.mtx.Lock()
-	defer ckv.mtx.Unlock()
-
 	types.AssertValidKey(key)
 
-	keyStr := string(key)
-	value, ok := ckv.cache.Get(keyStr)
-	if ok {
-		// cache hit
+	if value, ok := ckv.getFromCache(key); ok {
 		return value
 	}
 
-	// cache miss; write to cache
-	value = ckv.CommitKVStore.Get(key)
-	ckv.cache.Add(keyStr, value)
-
-	return value
+	// if not found in the cache, query the underlying CommitKVStore and initialize the cache value
+	return ckv.getAndWriteToCache(key)
 }
 
 // Set inserts a key/value pair into both the write-through cache and the
diff --git a/store/cachekv/store.go b/store/cachekv/store.go
index f03ee517e..9a21b695c 100644
--- a/store/cachekv/store.go
+++ b/store/cachekv/store.go
@@ -56,7 +56,7 @@ func (b mapCacheBackend) Range(f func(string, *types.CValue) bool) {
 
 // Store wraps an in-memory cache around an underlying types.KVStore.
 type Store struct {
-	mtx           sync.Mutex
+	mtx           sync.RWMutex
 	cache         *types.BoundedCache
 	deleted       *sync.Map
 	unsortedCache map[string]struct{}
@@ -104,20 +104,33 @@ func (store *Store) GetStoreType() types.StoreType {
 	return store.parent.GetStoreType()
 }
 
-// Get implements types.KVStore.
-func (store *Store) Get(key []byte) (value []byte) {
+// getFromCache queries the write-through cache for a value by key.
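+// It takes only a read lock, so concurrent readers do not serialize; on a
+// miss, Get falls through to getAndWriteToCache, which takes the write lock.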
+func (store *Store) getFromCache(key []byte) ([]byte, bool) {
+	store.mtx.RLock()
+	defer store.mtx.RUnlock()
+	if cv, ok := store.cache.Get(conv.UnsafeBytesToStr(key)); ok {
+		return cv.Value(), true
+	}
+	return nil, false
+}
+
+// getAndWriteToCache queries the underlying CommitKVStore and writes the result to the cache.
+func (store *Store) getAndWriteToCache(key []byte) []byte {
 	store.mtx.Lock()
 	defer store.mtx.Unlock()
+	value := store.parent.Get(key)
+	store.setCacheValue(key, value, false, false)
+	return value
+}
 
+// Get implements types.KVStore.
+func (store *Store) Get(key []byte) (value []byte) {
 	types.AssertValidKey(key)
 
-	cacheValue, ok := store.cache.Get(conv.UnsafeBytesToStr(key))
+	value, ok := store.getFromCache(key)
 	if !ok {
 		// TODO: (occ) This is an example of when we fall through when we don't have a cache hit. Similarly, for mvkv, we'll try to serve reads from a local cache that's transient to the TX, and if it's NOT present, then we read through AND mark the access (along with the value that was read) for validation
-		value = store.parent.Get(key)
-		store.setCacheValue(key, value, false, false)
-	} else {
-		value = cacheValue.Value()
+		value = store.getAndWriteToCache(key)
 	}
 	// TODO: (occ) This is an example of how we currently track accesses
 	store.eventManager.EmitResourceAccessReadEvent("get", store.storeKey, key, value)
@@ -140,8 +153,8 @@ func (store *Store) Set(key []byte, value []byte) {
 // Has implements types.KVStore.
 func (store *Store) Has(key []byte) bool {
 	value := store.Get(key)
-	store.mtx.Lock()
-	defer store.mtx.Unlock()
+	store.mtx.RLock()
+	defer store.mtx.RUnlock()
 	store.eventManager.EmitResourceAccessReadEvent("has", store.storeKey, key, value)
 	return value != nil
 }
diff --git a/store/multiversion/store.go b/store/multiversion/store.go
index bc5e8ee4a..16fb04597 100644
--- a/store/multiversion/store.go
+++ b/store/multiversion/store.go
@@ -320,11 +320,11 @@ func (s *Store) validateIterator(index int, tracker iterationTracker) bool {
 }
 
 func (s *Store) checkIteratorAtIndex(index int) bool {
-	s.mtx.RLock()
-	defer s.mtx.RUnlock()
-
 	valid := true
+	s.mtx.RLock()
 	iterateset := s.txIterateSets[index]
+	s.mtx.RUnlock()
+
 	for _, iterationTracker := range iterateset {
 		iteratorValid := s.validateIterator(index, iterationTracker)
 		valid = valid && iteratorValid
@@ -333,11 +333,12 @@ func (s *Store) checkReadsetAtIndex(index int) (bool, []int) {
-	s.mtx.RLock()
-	defer s.mtx.RUnlock()
-
 	conflictSet := make(map[int]struct{})
+
+	s.mtx.RLock()
 	readset := s.txReadSets[index]
+	s.mtx.RUnlock()
+
 	valid := true
 
 	// iterate over readset and check if the value is the same as the latest value relative to txIndex in the multiversion store
diff --git a/tasks/scheduler.go b/tasks/scheduler.go
index 7b1afc0d2..c00e70dbe 100644
--- a/tasks/scheduler.go
+++ b/tasks/scheduler.go
@@ -4,6 +4,7 @@ import (
 	"crypto/sha256"
 	"fmt"
 	"sort"
+	"sync"
 
 	"github.com/tendermint/tendermint/abci/types"
 	"go.opentelemetry.io/otel/attribute"
@@ -38,7 +39,6 @@ const (
 
 type deliverTxTask struct {
 	Ctx     sdk.Context
-	Span    trace.Span
 	AbortCh chan occ.Abort
 
 	Status        status
@@ -49,10 +49,10 @@ type deliverTxTask struct {
 	Request       types.RequestDeliverTx
 	Response      *types.ResponseDeliverTx
 	VersionStores map[sdk.StoreKey]*multiversion.VersionIndexedStore
+	ValidateCh    chan struct{}
 }
 
-func (dt *deliverTxTask) Increment() {
-	dt.Incarnation++
+func (dt *deliverTxTask) Reset() {
 	dt.Status = statusPending
 	dt.Response = nil
 	dt.Abort = nil
@@ -61,6 +61,11 @@ func (dt *deliverTxTask) Increment() {
 	dt.VersionStores = nil
 }
 
+func (dt *deliverTxTask) Increment() {
+	dt.Incarnation++
+	dt.ValidateCh = make(chan struct{}, 1)
+}
+
 // Scheduler processes tasks concurrently
 type Scheduler interface {
 	ProcessAll(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) ([]types.ResponseDeliverTx, error)
@@ -71,6 +76,7 @@ type scheduler struct {
 	workers            int
 	multiVersionStores map[sdk.StoreKey]multiversion.MultiVersionStore
 	tracingInfo        *tracing.Info
+	allTasks           []*deliverTxTask
 }
 
 // NewScheduler creates a new scheduler
@@ -111,9 +117,10 @@ func toTasks(reqs []*sdk.DeliverTxEntry) []*deliverTxTask {
 	res := make([]*deliverTxTask, 0, len(reqs))
 	for idx, r := range reqs {
 		res = append(res, &deliverTxTask{
-			Request: r.Request,
-			Index:   idx,
-			Status:  statusPending,
+			Request:    r.Request,
+			Index:      idx,
+			Status:     statusPending,
+			ValidateCh: make(chan struct{}, 1),
 		})
 	}
 	return res
@@ -175,6 +182,7 @@ func (s *scheduler) ProcessAll(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) ([]t
 	// prefill estimates
 	s.PrefillEstimates(ctx, reqs)
 	tasks := toTasks(reqs)
+	s.allTasks = tasks
 	toExecute := tasks
 	for !allValidated(tasks) {
 		var err error
@@ -193,9 +201,6 @@ func (s *scheduler) ProcessAll(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) ([]t
 		if err != nil {
 			return nil, err
 		}
-		for _, t := range toExecute {
-			t.Increment()
-		}
 	}
 	for _, mv := range s.multiVersionStores {
 		mv.WriteLatestToStore()
@@ -203,52 +208,83 @@ func (s *scheduler) ProcessAll(ctx sdk.Context, reqs []*sdk.DeliverTxEntry) ([]t
 	return collectResponses(tasks), nil
 }
 
-func (s *scheduler) validateAll(ctx sdk.Context, tasks []*deliverTxTask) ([]*deliverTxTask, error) {
-	spanCtx, span := s.tracingInfo.StartWithContext("SchedulerValidate", ctx.TraceSpanContext())
-	ctx = ctx.WithTraceSpanContext(spanCtx)
+func (s *scheduler) shouldRerun(task *deliverTxTask) bool {
+	switch task.Status {
+
+	case statusAborted, statusPending:
+		return true
+
+	// validated tasks can become unvalidated if an earlier re-run task now conflicts
+	case statusExecuted, statusValidated:
+		if valid, conflicts := s.findConflicts(task); !valid {
+			s.invalidateTask(task)
+
+			// if the conflicts are now validated, then rerun this task
+			if indexesValidated(s.allTasks, conflicts) {
+				return true
+			} else {
+				// otherwise, wait for the conflicting tasks to complete
+				task.Dependencies = conflicts
+				task.Status = statusWaiting
+				return false
+			}
+		} else if len(conflicts) == 0 {
+			// mark as validated, which will avoid re-validating unless a lower-index task re-validates
+			task.Status = statusValidated
+			return false
+		}
+		// valid but has conflicts, so it will be re-validated on the next pass
+		return false
+
+	case statusWaiting:
+		// if conflicts are done, then this task is ready to run again
+		return indexesValidated(s.allTasks, task.Dependencies)
+	}
+	panic("unexpected status: " + task.Status)
+}
+
+func (s *scheduler) validateTask(ctx sdk.Context, task *deliverTxTask) bool {
+	_, span := s.traceSpan(ctx, "SchedulerValidate", task)
 	defer span.End()
 
-	var res []*deliverTxTask
+	if s.shouldRerun(task) {
+		return false
+	}
+	return true
+}
 
-	// find first non-validated entry
-	var startIdx int
-	for idx, t := range tasks {
+func (s *scheduler) findFirstNonValidated() (int, bool) {
+	for i, t := range s.allTasks {
 		if t.Status != statusValidated {
-			startIdx = idx
-			break
+			return i, true
 		}
 	}
+	return 0, false
+}
 
-	for i := startIdx; i < len(tasks); i++ {
-		switch tasks[i].Status {
-		case statusAborted:
-			// aborted means it can be re-run immediately
-			res = append(res, tasks[i])
-
-		// validated tasks can become unvalidated if an earlier re-run task now conflicts
-		case statusExecuted, statusValidated:
-			if valid, conflicts := s.findConflicts(tasks[i]); !valid {
-				s.invalidateTask(tasks[i])
-
-				// if the conflicts are now validated, then rerun this task
-				if indexesValidated(tasks, conflicts) {
-					res = append(res, tasks[i])
-				} else {
-					// otherwise, wait for completion
-					tasks[i].Dependencies = conflicts
-					tasks[i].Status = statusWaiting
-				}
-			} else if len(conflicts) == 0 {
-				tasks[i].Status = statusValidated
-			} // TODO: do we need to have handling for conflicts existing here?
-
-		case statusWaiting:
-			// if conflicts are done, then this task is ready to run again
-			if indexesValidated(tasks, tasks[i].Dependencies) {
-				res = append(res, tasks[i])
+func (s *scheduler) validateAll(ctx sdk.Context, tasks []*deliverTxTask) ([]*deliverTxTask, error) {
+	ctx, span := s.traceSpan(ctx, "SchedulerValidateAll", nil)
+	defer span.End()
+
+	var mx sync.Mutex
+	var res []*deliverTxTask
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < len(tasks); i++ {
+		wg.Add(1)
+		go func(task *deliverTxTask) {
+			defer wg.Done()
+			if !s.validateTask(ctx, task) {
+				task.Reset()
+				task.Increment()
+				mx.Lock()
+				res = append(res, task)
+				mx.Unlock()
 			}
-		}
+		}(tasks[i])
 	}
+	wg.Wait()
+
 	return res, nil
 }
 
@@ -256,6 +292,9 @@ func (s *scheduler) validateAll(ctx sdk.Context, tasks []*deliverTxTask) ([]*del
 // Tasks are updated with their status
 // TODO: error scenarios
 func (s *scheduler) executeAll(ctx sdk.Context, tasks []*deliverTxTask) error {
+	ctx, span := s.traceSpan(ctx, "SchedulerExecuteAll", nil)
+	defer span.End()
+
 	ch := make(chan *deliverTxTask, len(tasks))
 	grp, gCtx := errgroup.WithContext(ctx.Context())
 
@@ -265,6 +304,15 @@ func (s *scheduler) executeAll(ctx sdk.Context, tasks []*deliverTxTask) error {
 		workers = len(tasks)
 	}
 
+	// validationWg waits for all validations to complete
+	// validations happen in separate goroutines so that each can wait on the previous index's validation
+	validationWg := &sync.WaitGroup{}
+	validationWg.Add(len(tasks))
+	grp.Go(func() error {
+		validationWg.Wait()
+		return nil
+	})
+
 	for i := 0; i < workers; i++ {
 		grp.Go(func() error {
 			for {
@@ -275,24 +323,16 @@ func (s *scheduler) executeAll(ctx sdk.Context, tasks []*deliverTxTask) error {
 				if !ok {
 					return nil
 				}
-				s.executeTask(task)
+				s.prepareAndRunTask(validationWg, ctx, task)
 			}
 		}
 	})
 }
-	grp.Go(func() error {
-		defer close(ch)
-		for _, task := range tasks {
-			s.prepareTask(ctx, task)
-
-			select {
-			case <-gCtx.Done():
-				return gCtx.Err()
-			case ch <- task:
-			}
-		}
-		return nil
-	})
+
+	for _, task := range tasks {
+		ch <- task
+	}
+	close(ch)
 
 	if err := grp.Wait(); err != nil {
 		return err
@@ -301,16 +341,46 @@ func (s *scheduler) executeAll(ctx sdk.Context, tasks []*deliverTxTask) error {
 	return nil
 }
 
+func (s *scheduler) prepareAndRunTask(wg *sync.WaitGroup, ctx sdk.Context, task *deliverTxTask) {
+	eCtx, eSpan := s.traceSpan(ctx, "SchedulerExecute", task)
+	defer eSpan.End()
+	task.Ctx = eCtx
+
+	s.executeTask(task.Ctx, task)
+	go func() {
+		defer wg.Done()
+		defer close(task.ValidateCh)
+		// wait on the previous task to finish validation
+		if task.Index > 0 {
+			<-s.allTasks[task.Index-1].ValidateCh
+		}
+		if !s.validateTask(task.Ctx, task) {
+			task.Reset()
+		}
+		task.ValidateCh <- struct{}{}
+	}()
+}
+
+func (s *scheduler) traceSpan(ctx sdk.Context, name string, task *deliverTxTask) (sdk.Context, trace.Span) {
+	spanCtx, span := s.tracingInfo.StartWithContext(name, ctx.TraceSpanContext())
+	if task != nil {
+		span.SetAttributes(attribute.String("txHash", fmt.Sprintf("%X", sha256.Sum256(task.Request.Tx))))
+		span.SetAttributes(attribute.Int("txIndex", task.Index))
+		span.SetAttributes(attribute.Int("txIncarnation", task.Incarnation))
+	}
+	ctx = ctx.WithTraceSpanContext(spanCtx)
+	return ctx, span
+}
+
 // prepareTask initializes the context and version stores for a task
 func (s *scheduler) prepareTask(ctx sdk.Context, task *deliverTxTask) {
-	// initialize the context
 	ctx = ctx.WithTxIndex(task.Index)
+
+	_, span := s.traceSpan(ctx, "SchedulerPrepare", task)
+	defer span.End()
+
+	// initialize the abort channel
 	abortCh := make(chan occ.Abort, len(s.multiVersionStores))
-	spanCtx, span := s.tracingInfo.StartWithContext("SchedulerExecute", ctx.TraceSpanContext())
-	span.SetAttributes(attribute.String("txHash", fmt.Sprintf("%X", sha256.Sum256(task.Request.Tx))))
-	span.SetAttributes(attribute.Int("txIndex", task.Index))
-	span.SetAttributes(attribute.Int("txIncarnation", task.Incarnation))
-	ctx = ctx.WithTraceSpanContext(spanCtx)
 
 	// if there are no stores, don't try to wrap, because there's nothing to wrap
 	if len(s.multiVersionStores) > 0 {
@@ -334,14 +404,17 @@ func (s *scheduler) prepareTask(ctx sdk.Context, task *deliverTxTask) {
 
 	task.AbortCh = abortCh
 	task.Ctx = ctx
-	task.Span = span
 }
 
-// executeTask executes a single task
-func (s *scheduler) executeTask(task *deliverTxTask) {
-	if task.Span != nil {
-		defer task.Span.End()
-	}
+// executeTask prepares and executes a single task
+func (s *scheduler) executeTask(ctx sdk.Context, task *deliverTxTask) {
+	s.prepareTask(ctx, task)
+
+	dCtx, dSpan := s.traceSpan(task.Ctx, "SchedulerDeliverTx", task)
+	defer dSpan.End()
+	task.Ctx = dCtx
+
 	resp := s.deliverTx(task.Ctx, task.Request)
 	close(task.AbortCh)
diff --git a/tasks/scheduler_test.go b/tasks/scheduler_test.go
index accc8bf3e..9d24b54a8 100644
--- a/tasks/scheduler_test.go
+++ b/tasks/scheduler_test.go
@@ -66,7 +66,7 @@ func TestProcessAll(t *testing.T) {
 		{
 			name:      "Test every tx accesses same key",
 			workers:   50,
-			runs:      25,
+			runs:      50,
 			addStores: true,
 			requests:  requestList(50),
 			deliverTxFunc: func(ctx sdk.Context, req types.RequestDeliverTx) types.ResponseDeliverTx {
@@ -94,7 +94,7 @@ func TestProcessAll(t *testing.T) {
 				}
 				// confirm last write made it to the parent store
 				latest := ctx.MultiStore().GetKVStore(testStoreKey).Get(itemKey)
-				require.Equal(t, []byte("49"), latest)
+				require.Equal(t, []byte(fmt.Sprintf("%d", len(res)-1)), latest)
 			},
 			expectedErr: nil,
 		},
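Reviewer note (not part of the patch): the concurrency pattern this change introduces is "execute in parallel, validate in transaction-index order." Each task's goroutine in `prepareAndRunTask` blocks on the previous task's `ValidateCh` before validating, then signals its own channel so the next index can proceed. Below is a minimal, standalone Go sketch of that chaining with simplified names (`validateChs` stands in for the per-task `ValidateCh`); it illustrates the pattern and is not code from this patch.

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	n := 5
	// one buffered signal channel per task, as with deliverTxTask.ValidateCh
	validateChs := make([]chan struct{}, n)
	for i := range validateChs {
		validateChs[i] = make(chan struct{}, 1)
	}

	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func(i int) {
			defer wg.Done()
			// ... execute transaction i concurrently here ...

			// chain validation: wait for the previous index to finish validating
			if i > 0 {
				<-validateChs[i-1]
			}
			fmt.Println("validating tx", i) // always prints 0, 1, 2, ... in order
			validateChs[i] <- struct{}{}    // signal the next index
			close(validateChs[i])
		}(i)
	}
	wg.Wait()
}
```

The buffer of 1 lets a task signal completion without blocking even if the next task has not yet reached its receive, and closing the channel afterwards unblocks any additional waiter; that is why the patch can register `defer close(task.ValidateCh)` before the send.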