Skip to content

Commit

Permalink
Initial FakeContext-based context impliementation.
Browse files Browse the repository at this point in the history
Not tested, just a progress commit.
  • Loading branch information
DPJacques committed Oct 19, 2024
1 parent 7e524bd commit 0841b94
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 18 deletions.
35 changes: 19 additions & 16 deletions clockwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ func (fc *FakeClock) After(d time.Duration) <-chan time.Time {
return fc.NewTimer(d).Chan()
}

// afterTime is like After, but uses a time instead of a duration.
//
// It is used to ensure FakeClock's lock is held constant through calling
// fc.After(t.Sub(fc.Now())). It should not be exposed externally.
func (fc *FakeClock) afterTime(t time.Time) <-chan time.Time {
return fc.newTimerAtTime(t, nil).Chan()
}

// Sleep blocks until the given duration has passed on the fakeClock.
func (fc *FakeClock) Sleep(d time.Duration) {
<-fc.After(d)
Expand Down Expand Up @@ -179,25 +187,20 @@ func (fc *FakeClock) AfterFunc(d time.Duration, f func()) Timer {

// newTimer returns a new timer, using an optional afterFunc.
func (fc *FakeClock) newTimer(d time.Duration, afterfunc func()) *fakeTimer {
var ft *fakeTimer
ft = &fakeTimer{
firer: newFirer(),
reset: func(d time.Duration) bool {
fc.l.Lock()
defer fc.l.Unlock()
// fc.l must be held across the calls to stopExpirer & setExpirer.
stopped := fc.stopExpirer(ft)
fc.setExpirer(ft, d)
return stopped
},
stop: func() bool { return fc.stop(ft) },

afterFunc: afterfunc,
}
ft := newFakeTimer(fc, afterfunc)
fc.set(ft, d)
return ft
}

// newTimerAtTime is like newTimer, but uses a time instead of a duration.
func (fc *FakeClock) newTimerAtTime(t time.Time, afterfunc func()) *fakeTimer {
ft := newFakeTimer(fc, afterfunc)
fc.l.Lock()
defer fc.l.Unlock()
fc.setExpirer(ft, t.Sub(fc.time))
return ft
}

// Advance advances fakeClock to a new point in time, ensuring waiters and
// blockers are notified appropriately before returning.
func (fc *FakeClock) Advance(d time.Duration) {
Expand Down Expand Up @@ -289,7 +292,7 @@ func (fc *FakeClock) stopExpirer(e expirer) bool {
return true
}

// set sets an expirer to expire at a future point in time.
// set sets an expirer to expire after a duration.
func (fc *FakeClock) set(e expirer, d time.Duration) {
fc.l.Lock()
defer fc.l.Unlock()
Expand Down
107 changes: 105 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package clockwork

import (
"context"
"errors"
"sync"
"time"
)

// contextKey is private to this package so we can ensure uniqueness here. This
// type identifies context values provided by this package.
type contextKey string

// keyClock provides a clock for injecting during tests. If absent, a real clock should be used.
// keyClock provides a clock for injecting during tests. If absent, a real clock
// should be used.
var keyClock = contextKey("clock") // clockwork.Clock

// AddToContext creates a derived context that references the specified clock.
Expand All @@ -21,10 +25,109 @@ func AddToContext(ctx context.Context, clock Clock) context.Context {
return context.WithValue(ctx, keyClock, clock)
}

// FromContext extracts a clock from the context. If not present, a real clock is returned.
// FromContext extracts a clock from the context. If not present, a real clock
// is returned.
func FromContext(ctx context.Context) Clock {
if clock, ok := ctx.Value(keyClock).(Clock); ok {
return clock
}
return NewRealClock()
}

type fakeClockContext struct {
parent context.Context
clock *FakeClock

deadline time.Time

mu sync.Mutex
done chan struct{}
err error
}

// WithDeadline returns a context with a deadline based on a [FakeClock].
//
// The returned context ignores parent cancelation if the parent was cancelled
// with a [context.DeadlineExceeded] error. Any other error returned by the
// parent is treated normally, cancelling the returned context.
//
// If the parent is cancelled with a [context.DeadlineExceeded] error, the only
// way to then cancel the returned context is by calling the returned
// context.CancelFunc.
func WithDeadline(parent context.Context, clock *FakeClock, t time.Time) (context.Context, context.CancelFunc) {
ctx := &fakeClockContext{
parent: parent,
}
cancelOnce := ctx.runCancel(clock.afterTime(t))
return &fakeClockContext{}, cancelOnce
}

// WithTimeout returns a context with a timeout based on a [FakeClock].
//
// The returned context follows the same behaviors as [WithDeadline].
func WithTimeout(parent context.Context, clock *FakeClock, d time.Duration) (context.Context, context.CancelFunc) {
ctx := &fakeClockContext{
parent: parent,
}
cancelOnce := ctx.runCancel(clock.After(d))
return &fakeClockContext{}, cancelOnce
}

func (c *fakeClockContext) setError(err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.err = err
close(c.done)
}

func (c *fakeClockContext) Deadline() (time.Time, bool) {
return time.Time{}, false
}

func (c *fakeClockContext) Done() <-chan struct{} {
return c.done
}

func (c *fakeClockContext) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}

func (c *fakeClockContext) Value(key any) any {
return c.parent.Value(key)
}

func (c *fakeClockContext) runCancel(clockExpCh <-chan time.Time) context.CancelFunc {
cancelCh := make(chan struct{})
result := sync.OnceFunc(func() {
close(cancelCh)
})

go func() {
select {
case <-clockExpCh:
c.setError(context.DeadlineExceeded)
return
case <-cancelCh:
c.setError(context.DeadlineExceeded)
return
case <-c.parent.Done():
if err := c.parent.Err(); !errors.Is(err, context.DeadlineExceeded) {
c.setError(err)
return
}
}

// The parent context has hit its deadline, but because we are using a fake
// clock we ignore it.
select {
case <-clockExpCh:
c.setError(context.DeadlineExceeded)
case <-cancelCh:
c.setError(context.DeadlineExceeded)
}
}()

return result
}
19 changes: 19 additions & 0 deletions timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,25 @@ type fakeTimer struct {
afterFunc func()
}

func newFakeTimer(fc *FakeClock, afterfunc func()) *fakeTimer {
var ft *fakeTimer
ft = &fakeTimer{
firer: newFirer(),
reset: func(d time.Duration) bool {
fc.l.Lock()
defer fc.l.Unlock()
// fc.l must be held across the calls to stopExpirer & setExpirer.
stopped := fc.stopExpirer(ft)
fc.setExpirer(ft, d)
return stopped
},
stop: func() bool { return fc.stop(ft) },

afterFunc: afterfunc,
}
return ft
}

func (f *fakeTimer) Reset(d time.Duration) bool {
return f.reset(d)
}
Expand Down

0 comments on commit 0841b94

Please sign in to comment.