Skip to content

Commit

Permalink
message: implement "packet" directly in Message type
Browse files Browse the repository at this point in the history
This really does not need to be its own type: it's really just a message
header. Absorb this functionality into the Message type directly to
consolidate.
  • Loading branch information
enr0n committed Aug 27, 2024
1 parent 2ae9d00 commit 668206e
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 422 deletions.
44 changes: 26 additions & 18 deletions vici/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
Expand All @@ -46,7 +47,7 @@ type clientConn struct {
conn net.Conn
}

func (cc *clientConn) packetWrite(ctx context.Context, p *packet) error {
func (cc *clientConn) packetWrite(ctx context.Context, m *Message) error {
if err := cc.conn.SetWriteDeadline(time.Time{}); err != nil {
return err
}
Expand All @@ -55,15 +56,15 @@ func (cc *clientConn) packetWrite(ctx context.Context, p *packet) error {
case <-ctx.Done():
err := cc.conn.SetWriteDeadline(time.Now())
return errors.Join(err, ctx.Err())
case err := <-cc.awaitPacketWrite(p):
case err := <-cc.awaitPacketWrite(m):
if err != nil {
return err
}
return nil
}
}

func (cc *clientConn) packetRead(ctx context.Context) (*packet, error) {
func (cc *clientConn) packetRead(ctx context.Context) (*Message, error) {
if err := cc.conn.SetReadDeadline(time.Time{}); err != nil {
return nil, err
}
Expand All @@ -72,21 +73,26 @@ func (cc *clientConn) packetRead(ctx context.Context) (*packet, error) {
case <-ctx.Done():
err := cc.conn.SetReadDeadline(time.Now())
return nil, errors.Join(err, ctx.Err())
case p := <-cc.awaitPacketRead():
if p.err != nil {
return nil, p.err
case v := <-cc.awaitPacketRead():
switch v.(type) {

Check failure on line 77 in vici/client_conn.go

View workflow job for this annotation

GitHub Actions / lint

typeSwitchVar: 2 cases can benefit from type switch with assignment (gocritic)
case error:
return nil, v.(error)

Check failure on line 79 in vici/client_conn.go

View workflow job for this annotation

GitHub Actions / lint

S1034(related information): could eliminate this type assertion (gosimple)
case *Message:
return v.(*Message), nil

Check failure on line 81 in vici/client_conn.go

View workflow job for this annotation

GitHub Actions / lint

S1034(related information): could eliminate this type assertion (gosimple)
default:
// This is a programmer error.
return nil, fmt.Errorf("%v: invalid packet read", errEncoding)
}
return p, nil
}
}

func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error {
func (cc *clientConn) awaitPacketWrite(m *Message) <-chan error {
r := make(chan error, 1)
buf := bytes.NewBuffer([]byte{})

go func() {
defer close(r)
b, err := p.bytes()
b, err := m.encode()
if err != nil {
r <- err
return
Expand All @@ -111,32 +117,34 @@ func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error {
return r
}

func (cc *clientConn) awaitPacketRead() <-chan *packet {
r := make(chan *packet, 1)
func (cc *clientConn) awaitPacketRead() <-chan any {
r := make(chan any, 1)

go func() {
defer close(r)
p := &packet{}
m := NewMessage()

buf := make([]byte, headerLength)
_, err := io.ReadFull(cc.conn, buf)
if err != nil {
p.err = err
r <- p
r <- err
return
}
pl := binary.BigEndian.Uint32(buf)

buf = make([]byte, int(pl))
_, err = io.ReadFull(cc.conn, buf)
if err != nil {
p.err = err
r <- p
r <- err
return
}

if err := m.decode(buf); err != nil {
r <- err
return
}

p.parse(buf)
r <- p
r <- m
}()

return r
Expand Down
32 changes: 20 additions & 12 deletions vici/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type eventListener struct {

// Packet channel used to communicate event registration
// results.
pc chan *packet
pc chan *Message

muChans sync.Mutex
chans map[chan<- Event]struct{}
Expand Down Expand Up @@ -65,7 +65,7 @@ type Event struct {
func newEventListener(cc *clientConn) *eventListener {
el := &eventListener{
cc: cc,
pc: make(chan *packet, 4),
pc: make(chan *Message, 4),
chans: make(map[chan<- Event]struct{}),
}

Expand Down Expand Up @@ -93,18 +93,18 @@ func (el *eventListener) listen() {
defer el.closeAllChans()

for {
p, err := el.cc.packetRead(context.Background())
m, err := el.cc.packetRead(context.Background())
if err != nil {
return
}

ts := time.Now()

switch p.ptype {
switch m.header.ptype {
case pktEvent:
e := Event{
Name: p.name,
Message: p.msg,
Name: m.header.name,
Message: m,
Timestamp: ts,
}

Expand All @@ -114,7 +114,7 @@ func (el *eventListener) listen() {
// requests from the event listener. Forward them over
// the packet channel.
case pktEventConfirm, pktEventUnknown:
el.pc <- p
el.pc <- m
}
}
}
Expand Down Expand Up @@ -219,25 +219,33 @@ func (el *eventListener) unregisterEvents(events []string, all bool) error {
}

func (el *eventListener) eventRequest(ptype uint8, event string) error {
p := newPacket(ptype, event, nil)
m := &Message{
header: &struct {
ptype uint8
name string
}{
ptype: ptype,
name: event,
},
}

if err := el.cc.packetWrite(context.Background(), p); err != nil {
if err := el.cc.packetWrite(context.Background(), m); err != nil {
return err
}

// The response packet is read by listen(), and written over pc.
p, ok := <-el.pc
m, ok := <-el.pc
if !ok {
return io.ErrClosedPipe
}

switch p.ptype {
switch m.header.ptype {
case pktEventConfirm:
return nil
case pktEventUnknown:
return fmt.Errorf("%v: %v", errEventUnknown, event)
default:
return fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype)
return fmt.Errorf("%v: %v", errUnexpectedResponse, m.header.ptype)
}
}

Expand Down
114 changes: 105 additions & 9 deletions vici/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,32 @@ const (
msgListEnd
)

const (
// A name request message
pktCmdRequest uint8 = iota

// An unnamed response message for a request
pktCmdResponse

// An unnamed response if requested command is unknown
pktCmdUnkown

// A named event registration request
pktEventRegister

// A name event deregistration request
pktEventUnregister

// An unnamed response for successful event (de-)registration
pktEventConfirm

// An unnamed response if event (de-)registration failed
pktEventUnknown

// A named event message
pktEvent
)

var (
// Generic encoding/decoding and marshaling/unmarshaling errors
errEncoding = errors.New("vici: error encoding message")
Expand All @@ -69,6 +95,7 @@ var (
errMalformedMessage = errors.New("vici: malformed message")

// Malformed message errors
errBadName = fmt.Errorf("%v: expected name length does not match actual length", errDecoding)
errBadKey = fmt.Errorf("%v: expected key length does not match actual length", errMalformedMessage)
errBadValue = fmt.Errorf("%v: expected value length does not match actual length", errMalformedMessage)
errEndOfBuffer = fmt.Errorf("%v: unexpected end of buffer", errMalformedMessage)
Expand Down Expand Up @@ -97,16 +124,21 @@ var (
// for convenience, and may have rules on how they are converted to an appropriate internal message
// element type. See Message.Set and MarshalMessage for details.
type Message struct {
// Packet header. Set only for reading and writing message packets.
header *struct {
ptype uint8
name string
}
keys []string

data map[string]any
}

// NewMessage returns an empty Message.
func NewMessage() *Message {
return &Message{
keys: make([]string, 0),
data: make(map[string]any),
header: nil,
keys: make([]string, 0),
data: make(map[string]any),
}
}

Expand Down Expand Up @@ -225,6 +257,33 @@ func (m *Message) Err() error {
return nil
}

// packetIsNamed returns a bool indicating the packet is a named type
func (m *Message) packetIsNamed() bool {
if m.header == nil {
return false
}

switch m.header.ptype {
case /* Named packet types */
pktCmdRequest,
pktEventRegister,
pktEventUnregister,
pktEvent:

return true

case /* Un-named packet types */
pktCmdResponse,
pktCmdUnkown,
pktEventConfirm,
pktEventUnknown:

return false
}

return false
}

func (m *Message) addItem(key string, value any) error {
// Check if the key is already set in the message
_, exists := m.data[key]
Expand Down Expand Up @@ -317,6 +376,24 @@ func safePutUint32(buf *bytes.Buffer, val int) error {
func (m *Message) encode() ([]byte, error) {
buf := bytes.NewBuffer([]byte{})

if m.header != nil {
if err := buf.WriteByte(m.header.ptype); err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}

if m.packetIsNamed() {
err := safePutUint8(buf, len(m.header.name))
if err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}

_, err = buf.WriteString(m.header.name)
if err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}
}
}

for k, v := range m.elements() {
rv := reflect.ValueOf(v)

Expand Down Expand Up @@ -369,14 +446,38 @@ func (m *Message) encode() ([]byte, error) {
}

func (m *Message) decode(data []byte) error {
m.header = &struct {
ptype uint8
name string
}{}
buf := bytes.NewBuffer(data)

// Parse the message header first.
b, err := buf.ReadByte()
if err != nil && err != io.EOF {
if err != nil {
return fmt.Errorf("%v: %v", errDecoding, err)
}
m.header.ptype = b

if m.packetIsNamed() {
l, err := buf.ReadByte()
if err != nil {
return fmt.Errorf("%v: %v", errDecoding, err)
}

if name := buf.Next(int(l)); len(name) != int(l) {
return errBadName
} else {

Check failure on line 470 in vici/message.go

View workflow job for this annotation

GitHub Actions / lint

indent-error-flow: if block ends with a return statement, so drop this else and outdent its block (move short variable declaration to its own line if necessary) (revive)
m.header.name = string(name)
}
}

for buf.Len() > 0 {
b, err = buf.ReadByte()
if err != nil && err != io.EOF {
return fmt.Errorf("%v: %v", errDecoding, err)
}

// Determine the next message element
switch b {
case msgKeyValue:
Expand All @@ -400,11 +501,6 @@ func (m *Message) decode(data []byte) error {
}
buf.Next(n)
}

b, err = buf.ReadByte()
if err != nil && err != io.EOF {
return fmt.Errorf("%v: %v", errDecoding, err)
}
}

return nil
Expand Down
Loading

0 comments on commit 668206e

Please sign in to comment.