Skip to content

Commit

Permalink
Add Expirations() and BlockUntilContext(...) to FakeClock.
Browse files Browse the repository at this point in the history
- Expirations() allows callers to validate that AfterFunc was not called.
- BlockUntilContext(...) was added to fakeClock previously but never exposed in
  the interface. Oops.
- Other NITs: documentation, spellin'

We choose to add both functions in the same commit becuase any change to the
interface requires toilsome updates by downstream users. Might as well do both
functions at once.
  • Loading branch information
DPJacques committed Jul 16, 2023
1 parent 8a29bc1 commit b176db8
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 13 deletions.
61 changes: 49 additions & 12 deletions clockwork.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package clockwork contains a simple fake clock for Go.
package clockwork

import (
Expand Down Expand Up @@ -28,10 +29,38 @@ type Clock interface {
// expected number of waiters.
type FakeClock interface {
Clock
// Advance advances the FakeClock to a new point in time, ensuring any existing
// waiters are notified appropriately before returning.

// Advance advances the FakeClock to a new point in time, ensuring the expiration count is updated
// and any existing waiters are notified appropriately before returning.
Advance(d time.Duration)

// Expirations returns the total number of expirations over the lifetime of the clock.
//
// The return value only increments during calls to Advance() allowing callers to, among other
// things, synchronously validate that a function provided to AfterFunc was not called.
//
// Expirations() increments when any of the following occur:
// - A Timer expires.
// - A value is sent on a channel returned by Ticker.Chan(). This happens regardless of whether
// the value is received by a caller. I.e. ticks that are dropped to make up for slow receivers
// still cause this value to increment. For details, see documentation on time.NewTicker.
// - A valure is sent on a channel returned by After.
// - A function provided to AfterFunc is called. Note this increments before the goroutine
// starts, so there is no race condition.
//
// The successful stopping of a Ticker or Timer, including Timers returned by AfterFunc, do not
// increment Expirations().
Expirations() int

// BlockUntilContext blocks until the fakeClock has the given number of waiters or the context is
// cancelled.
BlockUntilContext(ctx context.Context, n int) error

// BlockUntil blocks until the FakeClock has the given number of waiters.
//
// Prefer BlockUntilContext in new code, which offers context cancellation to prevent deadlock.
//
// Deprecated: New code should prefer BlockUntilContext.
BlockUntil(waiters int)
}

Expand Down Expand Up @@ -90,10 +119,11 @@ func (rc *realClock) AfterFunc(d time.Duration, f func()) Timer {
type fakeClock struct {
// l protects all attributes of the clock, including all attributes of all
// waiters and blockers.
l sync.RWMutex
waiters []expirer
blockers []*blocker
time time.Time
l sync.RWMutex
waiters []expirer
expirations int
blockers []*blocker
time time.Time
}

// blocker is a caller of BlockUntil.
Expand Down Expand Up @@ -202,11 +232,12 @@ func (fc *fakeClock) Advance(d time.Duration) {
w := fc.waiters[0]
fc.waiters = fc.waiters[1:]

// Use the waiter's expriation as the current time for this expiration.
// Use the waiter's expiration as the current time for this expiration.
now := w.expiry()
fc.time = now
fc.expirations++
if d := w.expire(now); d != nil {
// Set the new exipration if needed.
// Set the new expiration if needed.
fc.setExpirer(w, *d)
}
}
Expand All @@ -215,10 +246,10 @@ func (fc *fakeClock) Advance(d time.Duration) {

// BlockUntil blocks until the fakeClock has the given number of waiters.
//
// Prefer BlockUntilContext, which offers context cancellation to prevent
// deadlock.
// Prefer BlockUntilContext in new code, which offers context cancellation to
// prevent deadlock.
//
// Deprecation warning: This function might be deprecated in later versions.
// Deprecated: New code should prefer BlockUntilContext.
func (fc *fakeClock) BlockUntil(n int) {
b := fc.newBlocker(n)
if b == nil {
Expand All @@ -243,6 +274,12 @@ func (fc *fakeClock) BlockUntilContext(ctx context.Context, n int) error {
}
}

func (fc *fakeClock) Expirations() int {
fc.l.Lock()
defer fc.l.Unlock()
return fc.expirations
}

func (fc *fakeClock) newBlocker(n int) *blocker {
fc.l.Lock()
defer fc.l.Unlock()
Expand Down Expand Up @@ -307,7 +344,7 @@ func (fc *fakeClock) setExpirer(e expirer, d time.Duration) {
return fc.waiters[i].expiry().Before(fc.waiters[j].expiry())
})

// Notify blockers of our new waiter.
// Notify blockers of our new waiter.
var blocked []*blocker
count := len(fc.waiters)
for _, b := range fc.blockers {
Expand Down
75 changes: 75 additions & 0 deletions clockwork_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,78 @@ func TestFakeClockRace(t *testing.T) {
go func() { fc.NewTimer(d) }()
go func() { fc.Sleep(d) }()
}

func TestExpirations(t *testing.T) {
t.Parallel()

t.Run("AfterFunc increments counter", func(t *testing.T) {
fc := &fakeClock{}
fc.AfterFunc(time.Minute, func() {})
fc.AfterFunc(2*time.Minute, func() {})
fc.AfterFunc(2*time.Minute, func() {})
fc.AfterFunc(3*time.Minute, func() {})

start := fc.Now()

fc.Advance(time.Minute)
want := 1
if got := fc.Expirations(); got != want {
t.Errorf("after %v, fc.Expirations() = %v, want %v", fc.Since(start), got, want)
}

fc.Advance(time.Minute)
want = 3
if got := fc.Expirations(); got != want {
t.Errorf("after %v, fc.Expirations() = %v, want %v", fc.Since(start), got, want)
}

fc.Advance(30 * time.Second) // should not cause expirations.
if got := fc.Expirations(); got != want {
t.Errorf("after %v, fc.Expirations() = %v, want %v", fc.Since(start), got, want)
}
})

t.Run("Calls to Stop do not increment counter", func(t *testing.T) {
fc := &fakeClock{}
ticker := fc.NewTicker(time.Minute)
timer := fc.NewTimer(time.Minute)

start := fc.Now()

// Advance a little for good measure, but should have no effect.
fc.Advance(30 * time.Second)
want := 0
if got := fc.Expirations(); got != want {
t.Errorf("after %v, fc.Expirations() = %v, want %v", fc.Since(start), got, want)
}

timer.Stop()
ticker.Stop()
fc.Advance(time.Minute) // advances past the set expirations
if got := fc.Expirations(); got != want {
t.Errorf("after %v, fc.Expirations() = %v, want %v", fc.Since(start), got, want)
}
})

t.Run("Dropped ticks increment counter", func(t *testing.T) {
fc := &fakeClock{}
ticker := fc.NewTicker(time.Minute)

fc.Advance(2 * time.Minute)
want := 2
if got := fc.Expirations(); got != want {
t.Errorf("fc.Expirations() = %v, want %v", got, want)
}
// As of this writing I am using a variable for ticker because I don't know
// if assigning it to _ makes it eligible for garbage collection.
//
// Since we have to use the variable to appease the compiler, make sure we
// can receive on the ticker channel.
select {
case <-ticker.Chan(): //
default:
t.Errorf("Ticker should have fired at least once.")
}

})
}
2 changes: 1 addition & 1 deletion timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (r realTimer) Chan() <-chan time.Time {
type fakeTimer struct {
firer

// reset and stop provide the implmenetation of the respective exported
// reset and stop provide the implementation of the respective exported
// functions.
reset func(d time.Duration) bool
stop func() bool
Expand Down

0 comments on commit b176db8

Please sign in to comment.