Skip to content

Commit

Permalink
http: send headers with Sec-WebSocket-* capitalization
Browse files Browse the repository at this point in the history
This commit changes headers capitalization sent by ws.Dialer.
Now we send headers with Sec-WebSocket-* capitalization, but still don't being picky 
about which capitalization we receive.
qguv authored and gobwas committed Apr 27, 2019
1 parent a6b9551 commit 7338e26
Showing 5 changed files with 79 additions and 33 deletions.
10 changes: 5 additions & 5 deletions dialer.go
Original file line number Diff line number Diff line change
@@ -371,14 +371,14 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha
}

switch btsToString(k) {
case headerUpgrade:
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !btsEqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
return
}

case headerConnection:
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
// Note that as RFC6455 says:
// > A |Connection| header field with value "Upgrade".
@@ -389,14 +389,14 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha
return
}

case headerSecAccept:
case headerSecAcceptCanonical:
headerSeen |= headerSeenSecAccept
if !checkAcceptFromNonce(v, nonce) {
err = ErrHandshakeBadSecAccept
return
}

case headerSecProtocol:
case headerSecProtocolCanonical:
// RFC6455 1.3:
// "The server selects one or none of the acceptable protocols
// and echoes that value in its handshake to indicate that it has
@@ -414,7 +414,7 @@ func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Ha
return
}

case headerSecExtensions:
case headerSecExtensionsCanonical:
hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
if err != nil {
return
2 changes: 1 addition & 1 deletion example/autobahn/autobahn.go
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ func main() {

ln, err := net.Listen("tcp", *addr)
if err != nil {
log.Fatalf("listen %q error: %v", err)
log.Fatalf("listen %q error: %v", *addr, err)
}
log.Printf("listening %s (%q)", ln.Addr(), *addr)

25 changes: 17 additions & 8 deletions http.go
Original file line number Diff line number Diff line change
@@ -39,14 +39,23 @@ var (
)

var (
headerHost = textproto.CanonicalMIMEHeaderKey("Host")
headerUpgrade = textproto.CanonicalMIMEHeaderKey("Upgrade")
headerConnection = textproto.CanonicalMIMEHeaderKey("Connection")
headerSecVersion = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Version")
headerSecProtocol = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Protocol")
headerSecExtensions = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Extensions")
headerSecKey = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Key")
headerSecAccept = textproto.CanonicalMIMEHeaderKey("Sec-Websocket-Accept")
headerHost = "Host"
headerUpgrade = "Upgrade"
headerConnection = "Connection"
headerSecVersion = "Sec-WebSocket-Version"
headerSecProtocol = "Sec-WebSocket-Protocol"
headerSecExtensions = "Sec-WebSocket-Extensions"
headerSecKey = "Sec-WebSocket-Key"
headerSecAccept = "Sec-WebSocket-Accept"

headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost)
headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection)
headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey)
headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
)

var (
26 changes: 13 additions & 13 deletions server.go
Original file line number Diff line number Diff line change
@@ -159,13 +159,13 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.
err = ErrHandshakeBadProtocol
} else if r.Host == "" {
err = ErrHandshakeBadHost
} else if u := httpGetHeader(r.Header, headerUpgrade); u != "websocket" && !strEqualFold(u, "websocket") {
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strEqualFold(u, "websocket") {
err = ErrHandshakeBadUpgrade
} else if c := httpGetHeader(r.Header, headerConnection); c != "Upgrade" && !strHasToken(c, "upgrade") {
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
err = ErrHandshakeBadConnection
} else if nonce = httpGetHeader(r.Header, headerSecKey); len(nonce) != nonceSize {
} else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
err = ErrHandshakeBadSecKey
} else if v := httpGetHeader(r.Header, headerSecVersion); v != "13" {
} else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
// According to RFC6455:
//
// If this version does not match a version understood by the server,
@@ -190,7 +190,7 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.
}
}
if check := u.Protocol; err == nil && check != nil {
ps := r.Header[headerSecProtocol]
ps := r.Header[headerSecProtocolCanonical]
for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
var ok bool
hs.Protocol, ok = strSelectProtocol(ps[i], check)
@@ -200,7 +200,7 @@ func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.
}
}
if check := u.Extension; err == nil && check != nil {
xs := r.Header[headerSecExtensions]
xs := r.Header[headerSecExtensionsCanonical]
for i := 0; i < len(xs) && err == nil; i++ {
var ok bool
hs.Extensions, ok = strSelectExtensions(xs[i], hs.Extensions, check)
@@ -466,39 +466,39 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
}

switch btsToString(k) {
case headerHost:
case headerHostCanonical:
headerSeen |= headerSeenHost
if onHost := u.OnHost; onHost != nil {
err = onHost(v)
}

case headerUpgrade:
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !btsEqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
}

case headerConnection:
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
err = ErrHandshakeBadConnection
}

case headerSecVersion:
case headerSecVersionCanonical:
headerSeen |= headerSeenSecVersion
if !bytes.Equal(v, specHeaderValueSecVersion) {
err = ErrHandshakeUpgradeRequired
}

case headerSecKey:
case headerSecKeyCanonical:
headerSeen |= headerSeenSecKey
if len(v) != nonceSize {
err = ErrHandshakeBadSecKey
} else {
copy(nonce[:], v)
}

case headerSecProtocol:
case headerSecProtocolCanonical:
if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
var ok bool
if custom != nil {
@@ -511,7 +511,7 @@ func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
}
}

case headerSecExtensions:
case headerSecExtensionsCanonical:
if custom, check := u.ExtensionCustom, u.Extension; custom != nil || check != nil {
var ok bool
if custom != nil {
49 changes: 43 additions & 6 deletions server_test.go
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ type upgradeCase struct {
nonce []byte
removeSecKey bool
badSecKey bool
secKeyHeader string

req *http.Request
res *http.Response
@@ -58,8 +59,23 @@ var upgradeCases = []upgradeCase{
}),
},
{
label: "lowercase",
nonce: mustMakeNonce(),
label: "base_canonical",
nonce: mustMakeNonce(),
secKeyHeader: headerSecKeyCanonical,
req: mustMakeRequest("GET", "ws://example.org", http.Header{
headerUpgrade: []string{"websocket"},
headerConnection: []string{"Upgrade"},
headerSecVersionCanonical: []string{"13"},
}),
res: mustMakeResponse(101, http.Header{
headerUpgrade: []string{"websocket"},
headerConnection: []string{"Upgrade"},
}),
},
{
label: "lowercase_headers",
nonce: mustMakeNonce(),
secKeyHeader: strings.ToLower(headerSecKey),
req: mustMakeRequest("GET", "ws://example.org", http.Header{
strings.ToLower(headerUpgrade): []string{"websocket"},
strings.ToLower(headerConnection): []string{"Upgrade"},
@@ -101,6 +117,24 @@ var upgradeCases = []upgradeCase{
}),
hs: Handshake{Protocol: "b"},
},
{
label: "subproto_lowercase_headers",
protocol: SelectFromSlice([]string{"b", "d"}),
nonce: mustMakeNonce(),
secKeyHeader: strings.ToLower(headerSecKey),
req: mustMakeRequest("GET", "ws://example.org", http.Header{
strings.ToLower(headerUpgrade): []string{"websocket"},
strings.ToLower(headerConnection): []string{"Upgrade"},
strings.ToLower(headerSecVersion): []string{"13"},
strings.ToLower(headerSecProtocol): []string{"a", "b", "c", "d"},
}),
res: mustMakeResponse(101, http.Header{
headerUpgrade: []string{"websocket"},
headerConnection: []string{"Upgrade"},
headerSecProtocol: []string{"b"},
}),
hs: Handshake{Protocol: "b"},
},
{
label: "subproto_comma",
protocol: SelectFromSlice([]string{"b", "d"}),
@@ -332,10 +366,13 @@ func TestHTTPUpgrader(t *testing.T) {
if test.badSecKey {
nonce = nonce[:nonceSize-1]
}
test.req.Header.Set(headerSecKey, string(nonce))
if test.secKeyHeader == "" {
test.secKeyHeader = headerSecKey
}
test.req.Header[test.secKeyHeader] = []string{string(nonce)}
}
if test.err == nil {
test.res.Header.Set(headerSecAccept, string(makeAccept(test.nonce)))
test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))}
}

// Need to emulate http server read request for truth test.
@@ -398,10 +435,10 @@ func TestUpgrader(t *testing.T) {
if test.badSecKey {
nonce = nonce[:nonceSize-1]
}
test.req.Header.Set(headerSecKey, string(nonce))
test.req.Header[headerSecKey] = []string{string(nonce)}
}
if test.err == nil {
test.res.Header.Set(headerSecAccept, string(makeAccept(test.nonce)))
test.res.Header[headerSecAccept] = []string{string(makeAccept(test.nonce))}
}

u := Upgrader{

0 comments on commit 7338e26

Please sign in to comment.