Use assertion function #7433

Draft · wants to merge 1 commit into base: priority
33 changes: 14 additions & 19 deletions service/matching/matcher_data.go
@@ -30,6 +30,8 @@ import (
"time"

enumsspb "go.temporal.io/server/api/enums/v1"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/softassert"
"go.temporal.io/server/common/util"
)

@@ -232,6 +234,7 @@ func (t *taskPQ) ForEachTask(pred func(*internalTask) bool, post func(*internalT

type matcherData struct {
config *taskQueueConfig
logger log.Logger

lock sync.Mutex // covers everything below, and all fields in any waitableMatchResult

@@ -246,12 +249,13 @@ type matcherData struct {
lastPoller time.Time // most recent poll start time
}

func newMatcherData(config *taskQueueConfig) matcherData {
func newMatcherData(config *taskQueueConfig, logger log.Logger) matcherData {
return matcherData{
config: config,
tasks: taskPQ{
ages: newBacklogAgeTracker(),
},
logger: logger,
}
}

@@ -293,7 +297,7 @@ func (d *matcherData) EnqueueTaskAndWait(ctxs []context.Context, task *internalT

if task.matchResult == nil {
d.tasks.Remove(task)
task.wake(&matchResult{ctxErr: ctx.Err(), ctxErrIdx: i})
task.wake(d.logger, &matchResult{ctxErr: ctx.Err(), ctxErrIdx: i})
}
})
defer stop() // nolint:revive // there's only ever a small number of contexts
@@ -341,7 +345,7 @@ func (d *matcherData) EnqueuePollerAndWait(ctxs []context.Context, poller *waiti
if poller.matchHeapIndex >= 0 {
d.pollers.Remove(poller)
}
poller.wake(&matchResult{ctxErr: ctx.Err(), ctxErrIdx: i})
poller.wake(d.logger, &matchResult{ctxErr: ctx.Err(), ctxErrIdx: i})
}
})
defer stop() // nolint:revive // there's only ever a small number of contexts
@@ -384,7 +388,7 @@ func (d *matcherData) ReprocessTasks(pred func(*internalTask) bool) []*internalT
func(task *internalTask) {
// for sync tasks: wake up waiters with a fake context error
// for backlog tasks: the caller should call finish()
task.wake(&matchResult{ctxErr: errReprocessTask, ctxErrIdx: -1})
task.wake(d.logger, &matchResult{ctxErr: errReprocessTask, ctxErrIdx: -1})
reprocess = append(reprocess, task)
},
)
@@ -486,10 +490,10 @@ func (d *matcherData) findAndWakeMatches() {
task.recycleToken = d.recycleToken

res := &matchResult{task: task, poller: poller}
task.wake(res)
task.wake(d.logger, res)
// for poll forwarder: skip waking poller, forwarder will call finishMatchAfterPollForward
if !task.isPollForwarder {
poller.wake(res)
poller.wake(d.logger, res)
}
// TODO(pri): consider having task forwarding work the same way, with a half-match,
// instead of full match and then pass forward result on response channel?
@@ -519,7 +523,7 @@ func (d *matcherData) finishMatchAfterPollForward(poller *waitingPoller, task *i
defer d.lock.Unlock()

if poller.matchResult == nil {
poller.wake(&matchResult{task: task, poller: poller})
poller.wake(d.logger, &matchResult{task: task, poller: poller})
}
}

@@ -552,9 +556,9 @@ func (w *waitableMatchResult) initMatch(d *matcherData) {
// call with matcherData.lock held.
// w.matchResult must be nil (can't call wake twice).
// w must not be in queues anymore.
func (w *waitableMatchResult) wake(res *matchResult) {
bugIf(w.matchResult != nil, "bug: wake called twice")
bugIf(w.matchHeapIndex >= 0, "bug: wake called but still in heap")
func (w *waitableMatchResult) wake(logger log.Logger, res *matchResult) {
softassert.That(logger, w.matchResult == nil, "wake called twice")
softassert.That(logger, w.matchHeapIndex < 0, "wake called but still in heap")
w.matchResult = res
w.matchCond.Signal()
}
@@ -628,12 +632,3 @@ func (s *simpleLimiter) consume(now int64, tokens int64) {
// burst from now and adding one interval.
s.ready = max(now, s.ready+s.burst.Nanoseconds()) - s.burst.Nanoseconds() + tokens*s.interval.Nanoseconds()
}

// simple assertions
// TODO(pri): replace by something that doesn't panic

func bugIf(cond bool, msg string) {
if cond {
panic(msg)
}
}
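
For context, this diff replaces the panicking `bugIf` helper (deleted above) with calls to `common/softassert`. A minimal sketch of what such a helper plausibly looks like, inferred only from the call sites in this diff (logger, condition, message) and from the review comments that treat its result as actionable; it is not the actual `softassert` implementation:

```go
// Package softassert (illustrative sketch only): log assertion failures
// instead of panicking, and report whether the asserted condition held.
package softassert

import "go.temporal.io/server/common/log"

// That logs an error when condition is false and returns the condition,
// so callers can decide to continue, return, or ignore the failure.
func That(logger log.Logger, condition bool, msg string) bool {
	if !condition {
		logger.Error("failed assertion: " + msg)
	}
	return condition
}
```

Compared with `bugIf`, call sites now assert what should be true (e.g. `softassert.That(logger, res.task != nil, ...)`) rather than flagging what should be impossible, and the process no longer dies on a matcher bookkeeping bug.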
10 changes: 9 additions & 1 deletion service/matching/physical_task_queue_manager.go
@@ -214,7 +214,15 @@ func newPhysicalTaskQueueManager(
return nil, err
}
}
pqMgr.priMatcher = newPriTaskMatcher(tqCtx, config, queue.partition, fwdr, pqMgr.taskValidator, pqMgr.metricsHandler)
pqMgr.priMatcher = newPriTaskMatcher(
tqCtx,
config,
queue.partition,
fwdr,
pqMgr.taskValidator,
logger,
pqMgr.metricsHandler,
)
pqMgr.matcher = pqMgr.priMatcher
} else {
pqMgr.backlogMgr = newBacklogManager(
31 changes: 19 additions & 12 deletions service/matching/pri_matcher.go
@@ -36,9 +36,11 @@ import (
"go.temporal.io/server/common"
"go.temporal.io/server/common/backoff"
"go.temporal.io/server/common/clock"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/primitives/timestamp"
"go.temporal.io/server/common/quotas"
"go.temporal.io/server/common/softassert"
"go.temporal.io/server/common/tqid"
"go.temporal.io/server/common/util"
)
@@ -59,7 +61,8 @@ type priTaskMatcher struct {
fwdr *priForwarder
validator taskValidator
metricsHandler metrics.Handler // namespace metric scope
numPartitions func() int // number of task queue partitions
logger log.Logger
numPartitions func() int // number of task queue partitions

limiterLock sync.Mutex
adminNsRate float64
@@ -118,12 +121,14 @@ func newPriTaskMatcher(
partition tqid.Partition,
fwdr *priForwarder,
validator taskValidator,
logger log.Logger,
metricsHandler metrics.Handler,
) *priTaskMatcher {
tm := &priTaskMatcher{
config: config,
data: newMatcherData(config),
data: newMatcherData(config, logger),
tqCtx: tqCtx,
logger: logger,
metricsHandler: metricsHandler,
partition: partition,
fwdr: fwdr,
@@ -178,7 +183,7 @@ func (tm *priTaskMatcher) forwardTasks(lim quotas.RateLimiter, retrier backoff.R
if res.ctxErr != nil {
return // task queue closing
}
bugIf(res.task == nil, "bug: bad match result in forwardTasks")
softassert.That(tm.logger, res.task != nil, "expected a task from match")
Review comment (Member): this should continue or return on failure, otherwise it'll just panic immediately

err := tm.forwardTask(res.task)

@@ -247,11 +252,13 @@ func (tm *priTaskMatcher) validateTasksOnRoot(lim quotas.RateLimiter, retrier ba
if res.ctxErr != nil {
return // task queue closing
}
bugIf(res.task == nil, "bug: bad match result in validateTasksOnRoot")
softassert.That(tm.logger, res.task != nil, "expected a task from match")
Review comment (Member): this too

task := res.task
bugIf(task.forwardCtx != nil || task.isSyncMatchTask() || task.source != enumsspb.TASK_SOURCE_DB_BACKLOG,
"bug: validator got a sync task")
softassert.That(tm.logger, task.forwardCtx == nil, "expected non-forwarded task")
softassert.That(tm.logger, !task.isSyncMatchTask(), "expected non-sync match task")
softassert.That(tm.logger, task.source == enumsspb.TASK_SOURCE_DB_BACKLOG, "expected backlog task")

maybeValid := tm.validator == nil || tm.validator.maybeValidate(task.event.AllocatedTaskInfo, tm.partition.TaskType())
if !maybeValid {
// We found an invalid one, complete it and go back for another immediately.
Expand All @@ -276,7 +283,7 @@ func (tm *priTaskMatcher) forwardPolls() {
if res.ctxErr != nil {
return // task queue closing
}
bugIf(res.poller == nil, "bug: bad match result in forwardPolls")
softassert.That(tm.logger, res.poller != nil, "expected a poller from match")
Review comment (Member): needs to continue or return

poller := res.poller
// We need to use the real source poller context since it has the poller id and
@@ -324,7 +331,7 @@ func (tm *priTaskMatcher) forwardPolls() {
func (tm *priTaskMatcher) Offer(ctx context.Context, task *internalTask) (bool, error) {
finish := func() (bool, error) {
res, ok := task.getResponse()
bugIf(!ok, "Offer must be given a sync match task")
softassert.That(tm.logger, ok, "expected a sync match task")
Review comment (Member): needs to return error?

if res.forwarded {
if res.forwardErr == nil {
// task was remotely sync matched on the parent partition
@@ -364,7 +371,7 @@ func (tm *priTaskMatcher) Offer(ctx context.Context, task *internalTask) (bool,
if res.ctxErr != nil {
return false, res.ctxErr
}
bugIf(res.poller == nil, "bug: bad match result in Offer")
softassert.That(tm.logger, res.poller != nil, "expected poller from match")
return finish()
}

@@ -403,9 +410,9 @@ again:
}
return nil, res.ctxErr
}
bugIf(res.poller == nil, "bug: bad match result in syncOfferTask")
softassert.That(tm.logger, res.poller != nil, "expected poller from match")
response, ok := task.getResponse()
bugIf(!ok, "OfferQuery/OfferNexusTask must be given a sync match task")
softassert.That(tm.logger, ok, "expected a sync match task")
// Note: if task was not forwarded, this will just be the zero value and nil.
// That's intended: the query/nexus handler in matchingEngine will wait for the real
// result separately.
@@ -553,7 +560,7 @@ func (tm *priTaskMatcher) poll(
}
return nil, errNoTasks
}
bugIf(res.task == nil, "bug: bad match result in poll")
softassert.That(tm.logger, res.task != nil, "expected task from match")

task := res.task
pollWasForwarded = task.isStarted()
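Several of the review comments above note that the new soft assertions in `forwardTasks`, `validateTasksOnRoot`, `forwardPolls`, and `Offer` only log and then fall straight into code that dereferences the possibly-nil value. Assuming `softassert.That` returns its condition (as in the sketch earlier), a hypothetical way to honor that feedback in the `forwardTasks` loop:

```go
// Sketch of the loop body in priTaskMatcher.forwardTasks: skip this iteration
// when the match result unexpectedly carries no task, rather than logging the
// assertion failure and then panicking on the nil dereference just below it.
if !softassert.That(tm.logger, res.task != nil, "expected a task from match") {
	continue
}
err := tm.forwardTask(res.task)
```

The same pattern, with an early `return` (or a returned error in `Offer`, which the reviewer asks about), would cover the other flagged call sites.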
10 changes: 7 additions & 3 deletions service/matching/pri_task_reader.go
@@ -38,10 +38,12 @@ import (
"go.temporal.io/server/common"
"go.temporal.io/server/common/backoff"
"go.temporal.io/server/common/clock"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/persistence"
serviceerrors "go.temporal.io/server/common/serviceerror"
"go.temporal.io/server/common/softassert"
"go.temporal.io/server/common/util"
"golang.org/x/sync/semaphore"
)
@@ -63,6 +65,7 @@ type (
backlogMgr *priBacklogManagerImpl
subqueue int
notifyC chan struct{} // Used as signal to notify pump of new tasks
logger log.Logger

lock sync.Mutex

@@ -98,6 +101,7 @@ func newPriTaskReader(
backlogMgr: backlogMgr,
subqueue: subqueue,
notifyC: make(chan struct{}, 1),
logger: backlogMgr.logger,
retrier: backoff.NewRetrier(
common.CreateReadTaskRetryPolicy(),
clock.NewRealTimeSource(),
@@ -410,7 +414,7 @@ func (tr *priTaskReader) signalNewTasks(resp subqueueCreateTasksResponse) {
// adding these tasks to outstandingTasks. So they should definitely not be there.
for _, t := range resp.tasks {
_, found := tr.outstandingTasks.Get(t.TaskId)
bugIf(found, "bug: newly-written task already present in outstanding tasks")
softassert.That(tr.logger, !found, "newly-written task already present in outstanding tasks")
Review comment (Member): this should probably do a slices.DeleteFunc and remove the ones that are found

}

tr.recordNewTasksLocked(resp.tasks)
@@ -445,8 +449,8 @@ func (tr *priTaskReader) getLoadedTasks() int {

func (tr *priTaskReader) ackTaskLocked(taskId int64) int64 {
wasAlreadyAcked, found := tr.outstandingTasks.Get(taskId)
bugIf(!found, "bug: completed task not found in outstandingTasks")
bugIf(wasAlreadyAcked.(bool), "bug: completed task was already acked")
softassert.That(tr.logger, found, "completed task not found in outstandingTasks")
softassert.That(tr.logger, !wasAlreadyAcked.(bool), "completed task was already acked")
Review comment (Member) on lines +452 to +453: probably just return if either of these are true?

tr.outstandingTasks.Put(taskId, true)
tr.loadedTasks--
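
The last two review comments suggest acting on the assertion results in `pri_task_reader.go` as well: dropping already-tracked tasks in `signalNewTasks` and returning early from `ackTaskLocked`. A hedged sketch of both, again assuming `softassert.That` returns its condition; the filtering loop stands in for the suggested `slices.DeleteFunc`, and the early-return value in `ackTaskLocked` is only a guess at reasonable semantics:

```go
// signalNewTasks (sketch): keep only tasks that are not already present in
// outstandingTasks, logging an assertion failure for each duplicate found.
kept := resp.tasks[:0]
for _, t := range resp.tasks {
	_, found := tr.outstandingTasks.Get(t.TaskId)
	if softassert.That(tr.logger, !found, "newly-written task already present in outstanding tasks") {
		kept = append(kept, t)
	}
}
resp.tasks = kept
tr.recordNewTasksLocked(resp.tasks)

// ackTaskLocked (sketch): bail out when the bookkeeping is inconsistent rather
// than acking twice; the || short-circuit keeps the type assertion safe when
// the task was not found at all.
wasAlreadyAcked, found := tr.outstandingTasks.Get(taskId)
if !softassert.That(tr.logger, found, "completed task not found in outstandingTasks") ||
	!softassert.That(tr.logger, !wasAlreadyAcked.(bool), "completed task was already acked") {
	return 0 // hypothetical: the real code would need to pick an appropriate ack level
}
tr.outstandingTasks.Put(taskId, true)
tr.loadedTasks--
```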