Skip to content

Commit

Permalink
message: add safePutUint{8,16,32} helpers
Browse files Browse the repository at this point in the history
Add helpers to check for overflow before writing an integer in big
endian form to a buffer. This helps address potential integer overflow
bugs.
  • Loading branch information
enr0n committed Aug 24, 2024
1 parent 92a0383 commit 8c27dfa
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 24 deletions.
5 changes: 1 addition & 4 deletions vici/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,7 @@ func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error {
}

// Write the packet length
pl := make([]byte, headerLength)
binary.BigEndian.PutUint32(pl, uint32(len(b)))
_, err = buf.Write(pl)
if err != nil {
if err := safePutUint32(buf, len(b)); err != nil {
r <- err
return
}
Expand Down
13 changes: 9 additions & 4 deletions vici/client_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func TestPacketWrite(t *testing.T) {

length := binary.BigEndian.Uint32(b)

// #nosec G115
if want := len(goldNamedPacketBytes); length != uint32(want) {
t.Errorf("Unexpected packet length: got %d, expected: %d", length, want)
}
Expand Down Expand Up @@ -108,13 +109,17 @@ func TestPacketRead(t *testing.T) {
}()

// Make a buffer big enough for the data and the header.
b := make([]byte, headerLength+len(goldNamedPacketBytes))
buf := new(bytes.Buffer)

binary.BigEndian.PutUint32(b[:headerLength], uint32(len(goldNamedPacketBytes)))
if err := safePutUint32(buf, len(goldNamedPacketBytes)); err != nil {
t.Fatalf("Unexpected error writing header: %v", err)
}

copy(b[headerLength:], goldNamedPacketBytes)
if _, err := buf.Write(goldNamedPacketBytes); err != nil {
t.Fatalf("Unexpected error writing packet: %v", err)
}

_, err := srvr.Write(b)
_, err := srvr.Write(buf.Bytes())
if err != nil {
t.Fatalf("Unexpected error sending bytes: %v", err)
}
Expand Down
73 changes: 58 additions & 15 deletions vici/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,58 @@ func (m *Message) elements() []messageElement {
return ordered
}

func safePutUint8(buf *bytes.Buffer, val int) error {
limit := ^uint8(0)

if int64(val) > int64(limit) {
return fmt.Errorf("val too long (%d > %d)", val, limit)
}

// We can safely convert now, because we just checked that it will not overflow.
// #nosec G115
if err := buf.WriteByte(uint8(val)); err != nil {
return err
}

return nil
}

func safePutUint16(buf *bytes.Buffer, val int) error {
limit := ^uint16(0)
b := make([]byte, 2)

if int64(val) > int64(limit) {
return fmt.Errorf("val too long (%d > %d)", val, limit)
}

// We can safely convert now, because we just checked that it will not overflow.
binary.BigEndian.PutUint16(b, uint16(val)) // #nosec G115

if _, err := buf.Write(b); err != nil {
return err
}

return nil
}

func safePutUint32(buf *bytes.Buffer, val int) error {
limit := ^uint32(0)
b := make([]byte, 4)

if int64(val) > int64(limit) {
return fmt.Errorf("val too long (%d > %d)", val, limit)
}

// We can safely convert now, because we just checked that it will not overflow.
binary.BigEndian.PutUint32(b, uint32(val)) // #nosec G115

if _, err := buf.Write(b); err != nil {
return err
}

return nil
}

func (m *Message) encode() ([]byte, error) {
buf := bytes.NewBuffer([]byte{})

Expand Down Expand Up @@ -372,22 +424,17 @@ func (m *Message) encodeKeyValue(key, value string) ([]byte, error) {
buf := bytes.NewBuffer([]byte{msgKeyValue})

// Write the key length and key
err := buf.WriteByte(uint8(len(key)))
if err != nil {
if err := safePutUint8(buf, len(key)); err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}

_, err = buf.WriteString(key)
_, err := buf.WriteString(key)
if err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}

// Write the value's length to the buffer as two bytes
vl := make([]byte, 2)
binary.BigEndian.PutUint16(vl, uint16(len(value)))

_, err = buf.Write(vl)
if err != nil {
if err := safePutUint16(buf, len(value)); err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}

Expand All @@ -412,7 +459,7 @@ func (m *Message) encodeList(key string, list []string) ([]byte, error) {
buf := bytes.NewBuffer([]byte{msgListStart})

// Write the key length and key
err := buf.WriteByte(uint8(len(key)))
err := safePutUint8(buf, len(key))
if err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}
Expand All @@ -430,11 +477,7 @@ func (m *Message) encodeList(key string, list []string) ([]byte, error) {
}

// Write the item's length to the buffer as two bytes
il := make([]byte, 2)
binary.BigEndian.PutUint16(il, uint16(len(item)))

_, err = buf.Write(il)
if err != nil {
if err := safePutUint16(buf, len(item)); err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}

Expand All @@ -461,7 +504,7 @@ func (m *Message) encodeSection(key string, section *Message) ([]byte, error) {
buf := bytes.NewBuffer([]byte{msgSectionStart})

// Write the key length and key
err := buf.WriteByte(uint8(len(key)))
err := safePutUint8(buf, len(key))
if err != nil {
return nil, fmt.Errorf("%v: %v", errEncoding, err)
}
Expand Down
2 changes: 1 addition & 1 deletion vici/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (p *packet) bytes() ([]byte, error) {

// Write the name, preceded by its length
if p.isNamed() {
err := buf.WriteByte(uint8(len(p.name)))
err := safePutUint8(buf, len(p.name))
if err != nil {
return nil, fmt.Errorf("%v: %v", errPacketWrite, err)
}
Expand Down
18 changes: 18 additions & 0 deletions vici/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,21 @@ func TestPacketBytes(t *testing.T) {
t.Fatalf("Encoded packet does not equal gold bytes.\nExpected: %v\nReceived: %v", goldUnnamedPacketBytes, b)
}
}

func TestPacketTooLong(t *testing.T) {
tooLong := make([]byte, 256)

for i := range tooLong {
tooLong[i] = 'a'
}

p := &packet{
ptype: pktCmdRequest,
name: string(tooLong),
}

_, err := p.bytes()
if err == nil {
t.Fatalf("Expected packet-too-long error due to %s", p.name)
}
}

0 comments on commit 8c27dfa

Please sign in to comment.