Skip to content

Commit

Permalink
timer: Add a NewTimer method to Clock
Browse files Browse the repository at this point in the history
The select statement/block is a major strength of Go, but it provides a
few challenges for the fake clock. In particular, we need a way to know
when the channel has been extracted, and ideally when it's in use.

To support tracking how many timers have their channels extracted, we
return a pointer to a channel that lives on an anonymous struct and
leverage runtime.SetFinalizer to decrement a reference-count when it
would get GC'd (likely because the select block was exited).

However, finalizers may be run a bit earlier than one would otherwise
expect (the documentation for [runtime.SetFinalizer] indicates
instruction/statement-level granularity on usage -- hence the existence
of [runtime.KeepAlive])

Due to the laxness of the guarantees from runtime.SetFinalizer in the
absense of caller-help with runtime.KeepaAlive calls, we don't expose
the finalizer-based accounting, and leave those unexported.

We can decide later whether to remove the finalizer-based accounting.
I'd like to get some mileage with it before deciding whether it would
even be useful.

On the bright side, the call to get the channel gives us a signal as to
when we're in the select block, and facilitates the
AwaitAggExtractedChans, and NumAggExtractedChans methods.

[runtime.SetFinalizer]: https://pkg.go.dev/runtime#SetFinalizer
[runtime.SetFinalizer]: https://pkg.go.dev/runtime#KeepAlive
  • Loading branch information
dfinkel committed Dec 11, 2023
1 parent 3d6b352 commit 350b3a2
Show file tree
Hide file tree
Showing 6 changed files with 580 additions and 7 deletions.
15 changes: 15 additions & 0 deletions clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ func (c defaultClock) ContextWithTimeout(ctx context.Context, d time.Duration) (
return context.WithTimeout(ctx, d)
}

func (c defaultClock) NewTimer(d time.Duration) Timer {
t := time.NewTimer(d)
return &defaultTimer{Timer: t}
}

// DefaultClock returns a clock that minimally wraps the `time` package
func DefaultClock() Clock {
return defaultClock{}
Expand Down Expand Up @@ -103,4 +108,14 @@ type Clock interface {
// uses the clock to determine the when the timeout has elapsed. Cause is
// ignored in Go 1.20 and earlier.
ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc)

// NewTimer returns a Timer implementation which will fire after at
// least the specified duration [d]. The Ch() method returns a channel,
// and should be called inline with the receive or select case.
//
// Timers are most useful in select/case blocks. For simple cases,
// SleepFor should be preferred.
//
// Stop() is inherently racy. Be wary of the return value.
NewTimer(d time.Duration) Timer
}
96 changes: 89 additions & 7 deletions fake/fake_clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
// testing and skipping through timestamps without having to actually sleep in
// the test.
type Clock struct {
mu sync.Mutex
current time.Time
// sleepers contains a map from a channel on which that
// sleeper is sleeping to a target-time. When time is advanced past a
Expand All @@ -28,9 +27,8 @@ type Clock struct {
// protection necessary).
cbsWG sync.WaitGroup

// cond is broadcasted() upon any sleep or wakeup event (mutations to
// sleepers or cbs).
cond sync.Cond
// timer tracker
timerTrack timerTracker

// counter tracking the number of wakeups (protected by mu).
wakeups int
Expand All @@ -51,6 +49,21 @@ type Clock struct {
// counter tracking the number of callbacks that have ever been
// registered (via AfterFunc) (protected by mu).
callbacksAggregate int

// counter tracking the number of extracted channels (protected by mu).
extractedChans int

// counter tracking the aggregate number of extracted channels (protected by mu).
extractedChansAggregate int

// counter tracking the number of number of aggregate signaled timer channels
signaledChans int

// cond is broadcasted() upon any sleep or wakeup event (mutations to
// sleepers or cbs).
cond sync.Cond

mu sync.Mutex
}

var _ clocks.Clock = (*Clock)(nil)
Expand All @@ -62,7 +75,11 @@ func NewClock(initialTime time.Time) *Clock {
sleepers: map[chan<- struct{}]time.Time{},
cbs: map[*stopTimer]time.Time{},
cond: sync.Cond{},
timerTrack: timerTracker{
timers: map[*fakeTimer]time.Time{},
},
}
fc.timerTrack.fc = &fc
fc.cond.L = &fc.mu
return &fc
}
Expand All @@ -77,6 +94,10 @@ func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int {
awoken++
}
}

timerWakeRes := f.timerTrack.wakeup(t)
f.signaledChans += timerWakeRes.notified

cbsRun := 0
for s, target := range f.cbs {
if target.Sub(t) <= 0 {
Expand All @@ -95,7 +116,7 @@ func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int {
f.callbackExecs += cbsRun
f.current = t
f.cond.Broadcast()
return awoken + cbsRun
return awoken + cbsRun + timerWakeRes.awoken
}

// SetClock skips the FakeClock to the specified time (forward or backwards) The
Expand Down Expand Up @@ -344,6 +365,22 @@ func (f *Clock) AfterFunc(d time.Duration, cb func()) clocks.StopTimer {
return s
}

// NewTimer creates a new Timer
func (f *Clock) NewTimer(d time.Duration) clocks.Timer {
target := f.Now().Add(d)
// Capacity 1 so sending never blocks
ch := make(chan time.Time, 1)

ft := fakeTimer{
ch: ch,
tracker: &f.timerTrack,
}

f.timerTrack.registerTimer(&ft, target)

return &ft
}

// NumCallbackExecs returns the number of registered callbacks that have been
// executed due to time advancement.
func (f *Clock) NumCallbackExecs() int {
Expand Down Expand Up @@ -396,8 +433,8 @@ func (f *Clock) AwaitRegisteredCallbacks(n int) {
}
}

// AwaitTimerAborts waits until the aggregate number of registered callbacks
// (via AfterFunc) exceeds its argument.
// AwaitTimerAborts waits until the aggregate number of aborted callbacks
// (via AfterFunc) or timers exceeds its argument.
func (f *Clock) AwaitTimerAborts(n int) {
f.mu.Lock()
defer f.mu.Unlock()
Expand All @@ -406,6 +443,51 @@ func (f *Clock) AwaitTimerAborts(n int) {
}
}

// AwaitAggExtractedChans waits the aggregate number of calls to Ch() on
// timers to equal or exceed its argument.
// To be be most useful, uses of the channel should directly call `.Ch()` on
// the timers and dereferencing the channel pointer.
func (f *Clock) AwaitAggExtractedChans(n int) {
f.mu.Lock()
defer f.mu.Unlock()
for f.extractedChansAggregate < n {
f.cond.Wait()
}
}

// NumAggExtractedChans returns the aggregate number of calls to Ch() on
// timers.
// To be be most useful, uses of the channel should directly call `.Ch()` on
// the timers and dereferencing the channel pointer.
func (f *Clock) NumAggExtractedChans() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.extractedChansAggregate
}

// numExtractedChans returns the aggregate number of calls to Ch() on
// timers.
func (f *Clock) numExtractedChans() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.extractedChans
}

// awaitExtractedChans waits the number of calls to Ch() on
// timers to equal or exceed its argument.
func (f *Clock) awaitExtractedChans(n int) {
f.mu.Lock()
defer f.mu.Unlock()
for f.extractedChans < n {
f.cond.Wait()
}
}

// RegisteredTimers returns the execution-times of registered timers.
func (f *Clock) RegisteredTimers() []time.Time {
return f.timerTrack.registeredTimers()
}

// WaitAfterFuncs blocks until all currently running AfterFunc callbacks
// return.
func (f *Clock) WaitAfterFuncs() {
Expand Down
Loading

0 comments on commit 350b3a2

Please sign in to comment.