Skip to content

Commit

Permalink
fix for checkptr (#102)
Browse files Browse the repository at this point in the history
* fix for checkptr
* remove unsafe methods

Fixes: #99
cristaloleg authored Mar 3, 2020
1 parent a2feb15 commit d1714d5
Showing 5 changed files with 11 additions and 105 deletions.
7 changes: 5 additions & 2 deletions cipher.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ws

import (
"encoding/binary"
"unsafe"
)

@@ -46,8 +47,10 @@ func Cipher(payload []byte, mask [4]byte, offset int) {
// Get number of uint64 parts remaining to process.
n = (n - ln - rn) >> 3
for i := 0; i < n; i++ {
v := (*uint64)(unsafe.Pointer(&payload[ln+(i<<3)]))
*v = *v ^ m2
idx := ln + (i << 3)
p := binary.LittleEndian.Uint64(payload[idx : idx+8])
p = p ^ m2
binary.LittleEndian.PutUint64(payload[idx:idx+8], p)
}
}

4 changes: 2 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
@@ -371,7 +371,7 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha
switch btsToString(k) {
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !btsEqualFold(v, specHeaderValueUpgrade) {
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
return
}
@@ -382,7 +382,7 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha
// > A |Connection| header field with value "Upgrade".
// That is, in server side, "Connection" header could contain
// multiple token. But in response it must contains exactly one.
if !bytes.Equal(v, specHeaderValueConnection) && !btsEqualFold(v, specHeaderValueConnection) {
if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
err = ErrHandshakeBadConnection
return
}
5 changes: 3 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import (
"io"
"net"
"net/http"
"strings"
"time"

"github.com/gobwas/httphead"
@@ -159,7 +160,7 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.
err = ErrHandshakeBadProtocol
} else if r.Host == "" {
err = ErrHandshakeBadHost
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strEqualFold(u, "websocket") {
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
err = ErrHandshakeBadUpgrade
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
err = ErrHandshakeBadConnection
@@ -475,7 +476,7 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {

case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !btsEqualFold(v, specHeaderValueUpgrade) {
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
}

43 changes: 1 addition & 42 deletions util.go
Original file line number Diff line number Diff line change
@@ -113,7 +113,7 @@ func strHasToken(header, token string) (has bool) {

func btsHasToken(header, token []byte) (has bool) {
httphead.ScanTokens(header, func(v []byte) bool {
has = btsEqualFold(v, token)
has = bytes.EqualFold(v, token)
return !has
})
return
@@ -199,47 +199,6 @@ func readLine(br *bufio.Reader) ([]byte, error) {
}
}

// strEqualFold checks s to be case insensitive equal to p.
// Note that p must be only ascii letters. That is, every byte in p belongs to
// range ['a','z'] or ['A','Z'].
func strEqualFold(s, p string) bool {
return btsEqualFold(strToBytes(s), strToBytes(p))
}

// btsEqualFold checks s to be case insensitive equal to p.
// Note that p must be only ascii letters. That is, every byte in p belongs to
// range ['a','z'] or ['A','Z'].
func btsEqualFold(s, p []byte) bool {
if len(s) != len(p) {
return false
}
n := len(s)
// Prepare manual conversion on bytes that not lay in uint64.
m := n % 8
for i := 0; i < m; i++ {
if s[i]|toLower != p[i]|toLower {
return false
}
}
// Iterate over uint64 parts of s.
n = (n - m) >> 3
if n == 0 {
// There are no more bytes to compare.
return true
}

for i := 0; i < n; i++ {
x := m + (i << 3)
av := *(*uint64)(unsafe.Pointer(&s[x]))
bv := *(*uint64)(unsafe.Pointer(&p[x]))
if av|toLower8 != bv|toLower8 {
return false
}
}

return true
}

func min(a, b int) int {
if a < b {
return a
57 changes: 0 additions & 57 deletions util_test.go
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
"context"
"flag"
"fmt"
"io"
"math/rand"
@@ -18,8 +17,6 @@ import (
"time"
)

var compareWithStd = flag.Bool("std", false, "compare with standard library implementation (if exists)")

var readLineCases = []struct {
label string
in string
@@ -352,60 +349,6 @@ var equalFoldCases = []equalFoldCase{
inequalAt(randomEqualLetters(512), 256),
}

func TestStrEqualFold(t *testing.T) {
for i, test := range equalFoldCases {
t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) {
if len(test.a) < 100 && len(test.b) < 100 {
t.Logf("\n\ta: %s\n\tb: %s\n", test.a, test.b)
}
exp := strings.EqualFold(test.a, test.b)
if act := strEqualFold(test.a, test.b); act != exp {
t.Errorf("strEqualFold(%q, %q) = %v; want %v", test.a, test.b, act, exp)
}
})
}
}

func BenchmarkStrEqualFold(b *testing.B) {
for i, bench := range equalFoldCases {
b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = strEqualFold(bench.a, bench.b)
}
})
}
if *compareWithStd {
for i, bench := range equalFoldCases {
b.Run(fmt.Sprintf("%s#%d_std", bench.label, i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = strings.EqualFold(bench.a, bench.b)
}
})
}
}
}

func BenchmarkBtsEqualFold(b *testing.B) {
for i, bench := range equalFoldCases {
ab, bb := []byte(bench.a), []byte(bench.b)
b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = btsEqualFold(ab, bb)
}
})
}
if *compareWithStd {
for i, bench := range equalFoldCases {
ab, bb := []byte(bench.a), []byte(bench.b)
b.Run(fmt.Sprintf("%s#%d_std", bench.label, i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = bytes.EqualFold(ab, bb)
}
})
}
}
}

func TestAsciiToInt(t *testing.T) {
for _, test := range []struct {
bts []byte

0 comments on commit d1714d5

Please sign in to comment.