Skip to content

Commit

Permalink
ws: reduce unsafe usage (#80)
Browse files Browse the repository at this point in the history
This commit removes stack-based slice usage from from ReadHeader(),
WriteHeader() and nonce type manipulation. That stack-based hacks were
working well on Linux but were crashing applications on Windows with
errors indicating that GC found pointer not to the allocated span.

For sure, allocation on heap will bring some penalty to the performance,
but since Go uses TCMalloc under the hood that penalties could be not
significant.

Fixes #73 #63 #60
  • Loading branch information
gobwas authored Jun 4, 2019
1 parent 78de805 commit 584c339
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 183 deletions.
18 changes: 8 additions & 10 deletions cipher.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ws

import "unsafe"
import (
"unsafe"
)

// Cipher applies XOR cipher to the payload using mask.
// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
Expand Down Expand Up @@ -40,15 +42,11 @@ func Cipher(payload []byte, mask [4]byte, offset int) {
m := *(*uint32)(unsafe.Pointer(&mask))
m2 := uint64(m)<<32 | uint64(m)

// Get pointer to payload at ln index to
// skip manual processed bytes above.
p := uintptr(unsafe.Pointer(&payload[ln]))
// Also skip right part as the division by 8 remainder.
// Divide it by 8 to get number of uint64 parts remaining to process.
n = (n - rn) >> 3
// Process the rest of bytes as uint64.
for i := 0; i < n; i, p = i+1, p+8 {
v := (*uint64)(unsafe.Pointer(p))
// Skip already processed right part.
// Get number of uint64 parts remaining to process.
n = (n - ln - rn) >> 3
for i := 0; i < n; i++ {
v := (*uint64)(unsafe.Pointer(&payload[ln+(i<<3)]))
*v = *v ^ m2
}
}
Expand Down
45 changes: 37 additions & 8 deletions cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,55 @@ import (
)

func TestCipher(t *testing.T) {
for i, test := range []struct {
in []byte
mask [4]byte
}{
type test struct {
name string
in []byte
mask [4]byte
offset int
}
cases := []test{
{
name: "simple",
in: []byte("Hello, XOR!"),
mask: [4]byte{1, 2, 3, 4},
},
{
name: "simple",
in: []byte("Hello, XOR!"),
mask: [4]byte{255, 255, 255, 255},
},
} {
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
}
for offset := 0; offset < 4; offset++ {
for tail := 0; tail < 8; tail++ {
for b64 := 0; b64 < 3; b64++ {
var (
ln = remain[offset]
rn = tail
n = b64*8 + ln + rn
)

p := make([]byte, n)
rand.Read(p)

var m [4]byte
rand.Read(m[:])

cases = append(cases, test{
in: p,
mask: m,
offset: offset,
})
}
}
}
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
// naive implementation of xor-cipher
exp := cipherNaive(test.in, test.mask, 0)
exp := cipherNaive(test.in, test.mask, test.offset)

res := make([]byte, len(test.in))
copy(res, test.in)
Cipher(res, test.mask, 0)
Cipher(res, test.mask, test.offset)

if !reflect.DeepEqual(res, exp) {
t.Errorf("Cipher(%v, %v):\nact:\t%v\nexp:\t%v\n", test.in, test.mask, res, exp)
Expand Down
4 changes: 1 addition & 3 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,7 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha
}
}()

// Stick nonce bytes to the stack.
var n nonce
nonce := n.bytes()
nonce := make([]byte, nonceSize)
initNonce(nonce)

httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
Expand Down
7 changes: 6 additions & 1 deletion http.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,12 @@ func httpWriteUpgradeRequest(
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
httpWriteHeaderBts(bw, headerSecKey, nonce[:])

// NOTE: write nonce bytes as a string to prevent heap allocation –
// WriteString() copy given string into its inner buffer, unlike Write()
// which may write p directly to the underlying io.Writer – which in turn
// will lead to p escape.
httpWriteHeader(bw, headerSecKey, btsToString(nonce))

if len(protocols) > 0 {
httpWriteHeaderKey(bw, headerSecProtocol)
Expand Down
92 changes: 20 additions & 72 deletions nonce.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package ws

import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"fmt"
"hash"
"io"
"math/rand"
"reflect"
"sync"
"unsafe"
)

const (
Expand All @@ -29,41 +25,9 @@ const (
acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
)

var webSocketMagic = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

var sha1Pool sync.Pool

// nonce helps to put nonce bytes on the stack and then retrieve stack-backed
// slice with unsafe.
type nonce [nonceSize]byte

// bytes returns slice of bytes backed by nonce array.
// Note that returned slice is only valid until nonce array is alive.
func (n *nonce) bytes() (bts []byte) {
h := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
*h = reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(n)),
Len: len(n),
Cap: len(n),
}
return bts
}

func acquireSha1() hash.Hash {
if h := sha1Pool.Get(); h != nil {
return h.(hash.Hash)
}
return sha1.New()
}

func releaseSha1(h hash.Hash) {
h.Reset()
sha1Pool.Put(h)
}

// initNonce fills given slice with random base64-encoded nonce bytes.
func initNonce(dst []byte) {
// NOTE: bts does not escapes.
// NOTE: bts does not escape.
bts := make([]byte, nonceKeySize)
if _, err := rand.Read(bts); err != nil {
panic(fmt.Sprintf("rand read error: %s", err))
Expand All @@ -85,48 +49,32 @@ func checkAcceptFromNonce(accept, nonce []byte) bool {

// initAcceptFromNonce fills given slice with accept bytes generated from given
// nonce bytes. Given buffer should be exactly acceptSize bytes.
func initAcceptFromNonce(dst, nonce []byte) {
if len(dst) != acceptSize {
func initAcceptFromNonce(accept, nonce []byte) {
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

if len(accept) != acceptSize {
panic("accept buffer is invalid")
}
if len(nonce) != nonceSize {
panic("nonce is invalid")
}

sha := acquireSha1()
defer releaseSha1(sha)
p := make([]byte, nonceSize+len(magic))
copy(p[:nonceSize], nonce)
copy(p[nonceSize:], magic)

sha.Write(nonce)
sha.Write(webSocketMagic)
sum := sha1.Sum(p)
base64.StdEncoding.Encode(accept, sum[:])

var (
sb [sha1.Size]byte
sum []byte
)
sh := (*reflect.SliceHeader)(unsafe.Pointer(&sum))
*sh = reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(&sb)),
Len: 0,
Cap: len(sb),
}
sum = sha.Sum(sum)

base64.StdEncoding.Encode(dst, sum)
return
}

func writeAccept(w io.Writer, nonce []byte) (int, error) {
var (
b [acceptSize]byte
bts []byte
)
bh := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
*bh = reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(&b)),
Len: len(b),
Cap: len(b),
}

initAcceptFromNonce(bts, nonce)

return w.Write(bts)
func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) {
accept := make([]byte, acceptSize)
initAcceptFromNonce(accept, nonce)
// NOTE: write accept bytes as a string to prevent heap allocation –
// WriteString() copy given string into its inner buffer, unlike Write()
// which may write p directly to the underlying io.Writer – which in turn
// will lead to p escape.
return bw.WriteString(btsToString(accept))
}
11 changes: 11 additions & 0 deletions nonce_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ws

import "testing"

func BenchmarkInitAcceptFromNonce(b *testing.B) {
dst := make([]byte, acceptSize)
nonce := mustMakeNonce()
for i := 0; i < b.N; i++ {
initAcceptFromNonce(dst, nonce)
}
}
21 changes: 2 additions & 19 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"encoding/binary"
"fmt"
"io"
"reflect"
"unsafe"
)

// Errors used by frame reader.
Expand All @@ -21,22 +19,7 @@ func ReadHeader(r io.Reader) (h Header, err error) {
// The maximum header size is 14, but due to the 2 hop reads,
// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
// So 14 - 2 = 12.
//
// We use unsafe to stick bts to stack and avoid allocations.
//
// Using stack based slice is safe here, cause golang docs for io.Reader
// says that "Implementations must not retain p".
// See https://golang.org/pkg/io/#Reader
var (
b [MaxHeaderSize - 2]byte // Stack based array.
bts []byte // Slice of bytes backed by stack based array.
)
bh := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
*bh = reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(&b)),
Len: 2,
Cap: len(b),
}
bts := make([]byte, 2, MaxHeaderSize-2)

// Prepare to hold first 2 bytes to choose size of next read.
_, err = io.ReadFull(r, bts)
Expand Down Expand Up @@ -76,7 +59,7 @@ func ReadHeader(r io.Reader) (h Header, err error) {
}

// Increase len of bts to extra bytes need to read.
// Overwrite frist 2 bytes read before.
// Overwrite first 2 bytes that was read before.
bts = bts[:extra]
_, err = io.ReadFull(r, bts)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
// headerSeen reports which header was seen by setting corresponding
// bit on.
headerSeen byte
nonce nonce

nonce = make([]byte, nonceSize)
)
for err == nil {
line, e := readLine(br)
Expand Down Expand Up @@ -585,7 +586,7 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
return
}

httpWriteResponseUpgrade(bw, nonce.bytes(), hs, header.WriteTo)
httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
err = bw.Flush()

return
Expand Down
Loading

0 comments on commit 584c339

Please sign in to comment.