Skip to content

Commit

Permalink
all: close then re-dial client conn on context deadline
Browse files Browse the repository at this point in the history
The current handling of context deadlines can lead to data ordering
problems, in a racey manner. For example, if command is cancelled after
the write succeeds, but before the response is read, subsequent commands
may read the previous response data. This can be partially helped by
always calling SetDeadline() instead of the Read/Write variants, but
this is still racey. E.g., the server may initiate the read after we
have initiated a new command and have cleared the deadline again.

Instead, when a context is cancelled, call Close() on the underlying
connection. Then, when another command request is made, re-dial the
connection to give ourselves a clean slate.
  • Loading branch information
enr0n committed Oct 1, 2024
1 parent 668206e commit 69f6b2f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 25 deletions.
66 changes: 54 additions & 12 deletions vici/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"fmt"
"io"
"net"
"time"
)

const (
Expand All @@ -44,19 +43,56 @@ var (
)

type clientConn struct {
conn net.Conn
network string
addr string
dialer func(ctx context.Context, network, addr string) (net.Conn, error)

closed bool
conn net.Conn
}

func (cc *clientConn) dial(ctx context.Context) error {
if !cc.closed && cc.conn != nil {
return nil
}

conn, err := cc.dialer(ctx, cc.network, cc.addr)
if err != nil {
return err
}

cc.conn = conn
cc.closed = false

return nil
}

func (cc *clientConn) Close() error {
if cc.closed || cc.conn == nil {
return nil
}

cc.closed = true

return cc.conn.Close()
}

func (cc *clientConn) packetWrite(ctx context.Context, m *Message) error {
if err := cc.conn.SetWriteDeadline(time.Time{}); err != nil {
if err := cc.dial(ctx); err != nil {
return err
}

rc := cc.asyncPacketWrite(m)
select {
case <-ctx.Done():
err := cc.conn.SetWriteDeadline(time.Now())
return errors.Join(err, ctx.Err())
case err := <-cc.awaitPacketWrite(m):
// Disconnect on context deadline to avoid data ordering
// problems with subsequent read/writes. Re-establish the
// connection later.
cc.Close()
<-rc

return ctx.Err()
case err := <-rc:
if err != nil {
return err
}
Expand All @@ -65,15 +101,21 @@ func (cc *clientConn) packetWrite(ctx context.Context, m *Message) error {
}

func (cc *clientConn) packetRead(ctx context.Context) (*Message, error) {
if err := cc.conn.SetReadDeadline(time.Time{}); err != nil {
if err := cc.dial(ctx); err != nil {
return nil, err
}

rc := cc.asyncPacketRead()
select {
case <-ctx.Done():
err := cc.conn.SetReadDeadline(time.Now())
return nil, errors.Join(err, ctx.Err())
case v := <-cc.awaitPacketRead():
// Disconnect on context deadline to avoid data ordering
// problems with subsequent read/writes. Re-establish the
// connection later.
cc.Close()
<-rc

return nil, ctx.Err()
case v := <-rc:
switch v.(type) {

Check failure on line 119 in vici/client_conn.go

View workflow job for this annotation

GitHub Actions / lint

typeSwitchVar: 2 cases can benefit from type switch with assignment (gocritic)
case error:
return nil, v.(error)

Check failure on line 121 in vici/client_conn.go

View workflow job for this annotation

GitHub Actions / lint

S1034(related information): could eliminate this type assertion (gosimple)
Expand All @@ -86,7 +128,7 @@ func (cc *clientConn) packetRead(ctx context.Context) (*Message, error) {
}
}

func (cc *clientConn) awaitPacketWrite(m *Message) <-chan error {
func (cc *clientConn) asyncPacketWrite(m *Message) <-chan error {
r := make(chan error, 1)
buf := bytes.NewBuffer([]byte{})

Expand Down Expand Up @@ -117,7 +159,7 @@ func (cc *clientConn) awaitPacketWrite(m *Message) <-chan error {
return r
}

func (cc *clientConn) awaitPacketRead() <-chan any {
func (cc *clientConn) asyncPacketRead() <-chan any {
r := make(chan any, 1)

go func() {
Expand Down
32 changes: 26 additions & 6 deletions vici/client_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestPacketRead(t *testing.T) {
<-done
}

func TestPacketWriteContext(t *testing.T) {
func TestPacketWriteContextCancel(t *testing.T) {
client, srvr := net.Pipe()
defer client.Close()
defer srvr.Close()
Expand All @@ -144,16 +144,26 @@ func TestPacketWriteContext(t *testing.T) {
t.Fatalf("Expected cancel on packet write, but got %v", err)
}

ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
}

Check failure on line 147 in vici/client_conn_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)
func TestPacketWriteContextTimeout(t *testing.T) {
client, srvr := net.Pipe()
defer client.Close()
defer srvr.Close()

cc := &clientConn{
conn: client,
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

err = cc.packetWrite(ctx, goldNamedPacket)
err := cc.packetWrite(ctx, goldNamedPacket)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected timeout on packet write, but got %v", err)
}
}

func TestPacketReadContext(t *testing.T) {
func TestPacketReadContextCancel(t *testing.T) {
client, srvr := net.Pipe()
defer client.Close()
defer srvr.Close()
Expand All @@ -169,11 +179,21 @@ func TestPacketReadContext(t *testing.T) {
if !errors.Is(err, context.Canceled) {
t.Fatalf("Expected cancel on packet read, but got %v", err)
}
}

func TestPacketReadContextTimeout(t *testing.T) {
client, srvr := net.Pipe()
defer client.Close()
defer srvr.Close()

ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
cc := &clientConn{
conn: client,
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

_, err = cc.packetRead(ctx)
_, err := cc.packetRead(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected timeout on packet read, but got %v", err)
}
Expand Down
5 changes: 4 additions & 1 deletion vici/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ func (el *eventListener) Close() error {
return err
}

el.cc.conn.Close()
if el.cc != nil {
el.cc.Close()
el.cc = nil
}

return nil
}
Expand Down
14 changes: 8 additions & 6 deletions vici/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,15 @@ func (s *Session) newClientConn() (*clientConn, error) {
return &clientConn{conn: s.conn}, nil
}

conn, err := s.dialer(context.Background(), s.network, s.addr)
if err != nil {
return nil, err
cc := &clientConn{
network: s.network,
addr: s.addr,
dialer: s.dialer,
conn: nil,
}

cc := &clientConn{
conn: conn,
if err := cc.dial(context.Background()); err != nil {
return nil, err
}

return cc, nil
Expand All @@ -107,7 +109,7 @@ func (s *Session) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cc != nil {
if err := s.cc.conn.Close(); err != nil {
if err := s.cc.Close(); err != nil {
return err
}

Expand Down

0 comments on commit 69f6b2f

Please sign in to comment.