Skip to content

Commit

Permalink
Closes upgraded connection on copy completion (#1670)
Browse files Browse the repository at this point in the history
Proxy copies data via `io.Copy` between upgraded request and backend connections in
two goroutines and waits for both copy calls to complete:

```
wait {
        copy backend to request // 1
        copy request to backend // 2
}
```

When backend connection is closed, first copy call completes
but the second is still blocked on read from request connection until
it is closed (i.e. client disconnects).

This change waits for either copy to complete and closes both request and
backend connections and thus unblocks the second copy.

Fixes #1669

Signed-off-by: Alexander Yastrebov <[email protected]>
  • Loading branch information
AlexanderYastrebov authored Jan 26, 2021
1 parent 657ad4e commit f5f2314
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 34 deletions.
28 changes: 15 additions & 13 deletions proxy/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"net/http/httputil"
"net/url"
"strings"
"sync"

log "github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -151,19 +150,22 @@ func (p *upgradeProxy) serveHTTP(w http.ResponseWriter, req *http.Request) {
return
}

var wg sync.WaitGroup
wg.Add(2)
done := make(chan struct{}, 2)

if p.useAuditLog {
copyAsync(&wg, backendConn, requestHijackedConn, p.auditLogOut)
copyAsync("backend->request+audit", backendConn, io.MultiWriter(requestHijackedConn, p.auditLogOut), done)
} else {
copyAsync(&wg, backendConn, requestHijackedConn)
copyAsync("backend->request", backendConn, requestHijackedConn, done)
}

copyAsync(&wg, requestHijackedConn, backendConn)
copyAsync("request->backend", requestHijackedConn, backendConn, done)

log.Debugf("Successfully upgraded to protocol %s by user request", getUpgradeRequest(req))
// Wait for goroutine to finish, such that the established connection does not break.
wg.Wait()

// Wait for either copyAsync to complete.
// Return from this method closes both request and backend connections via defer
// and thus unblocks the second copyAsync.
<-done

if p.useAuditLog {
select {
Expand Down Expand Up @@ -203,14 +205,14 @@ func (p *upgradeProxy) dialBackend(req *http.Request) (net.Conn, error) {
}
}

func copyAsync(wg *sync.WaitGroup, src io.Reader, dst ...io.Writer) {
func copyAsync(dir string, src io.Reader, dst io.Writer, done chan<- struct{}) {
go func() {
w := io.MultiWriter(dst...)
_, err := io.Copy(w, src)
_, err := io.Copy(dst, src)
// net: errClosing not exported https://github.com/golang/go/issues/4373
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
log.Errorf("error proxying data from src to dst: %v", err)
log.Errorf("error copying data %s: %v", dir, err)
}
wg.Done()
done <- struct{}{}
}()
}

Expand Down
56 changes: 35 additions & 21 deletions proxy/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -91,6 +91,7 @@ func TestServeHTTP(t *testing.T) {
route string
method string
backendClosesConnection bool
backendHangs bool
noBackend bool
backendStatusCode int
expectedResponseBody string
Expand Down Expand Up @@ -135,8 +136,21 @@ func TestServeHTTP(t *testing.T) {
expectedResponseBody: "BACKEND ERROR",
backendClosesConnection: true,
},
{
msg: "backend hangs",
route: `route: Path("/ws") -> "%s";`,
method: http.MethodGet,
backendStatusCode: http.StatusSwitchingProtocols,
backendHangs: true,
},
} {
t.Run(ti.msg, func(t *testing.T) {
ti := ti // trick race detector
var clientConnClosed atomic.Value
clientAlive := func() bool {
return clientConnClosed.Load() == nil
}

backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(ti.backendStatusCode)
if ti.backendClosesConnection {
Expand All @@ -157,12 +171,18 @@ func TestServeHTTP(t *testing.T) {
return
}
defer conn.Close()

for {
s, err := bufrw.ReadString('\n')
if err != nil {
if err != nil && clientAlive() {
t.Errorf("error reading string: %v", err)
return
}

if ti.backendHangs {
return // will close connection without response
}

var resp string
if strings.Compare(s, "ping\n") == 0 {
resp = "pong\n"
Expand All @@ -171,12 +191,12 @@ func TestServeHTTP(t *testing.T) {
}

_, err = bufrw.WriteString(resp)
if err != nil {
if err != nil && clientAlive() {
t.Error(err)
return
}
err = bufrw.Flush()
if err != nil {
if err != nil && clientAlive() {
t.Error(err)
return
}
Expand Down Expand Up @@ -205,7 +225,10 @@ func TestServeHTTP(t *testing.T) {
t.Error(err)
return
}
defer conn.Close()
defer func() {
clientConnClosed.Store(true)
conn.Close()
}()

u, _ := url.ParseRequestURI("wss://www.example.org/ws")
r := &http.Request{
Expand Down Expand Up @@ -245,7 +268,7 @@ func TestServeHTTP(t *testing.T) {
if ti.method == http.MethodPost || ti.noBackend {
return
}
t.Errorf("wrong response status <%d>, expeted <%d>", resp.StatusCode, ti.backendStatusCode)
t.Errorf("wrong response status <%d>, expected <%d>", resp.StatusCode, ti.backendStatusCode)
return
}

Expand All @@ -255,6 +278,12 @@ func TestServeHTTP(t *testing.T) {
return
}
pong, err := reader.ReadString('\n')
if ti.backendHangs {
if err != io.EOF {
t.Error("expected EOF on closed connection read")
}
return
}
if err != nil {
t.Error(err)
return
Expand Down Expand Up @@ -323,21 +352,6 @@ func TestInvalidHTTPDialBackend(t *testing.T) {
}
}

func TestCopyAsync(t *testing.T) {
var dst bytes.Buffer
var wg sync.WaitGroup
wg.Add(1)
s := "foo"
src := bytes.NewBufferString(s)

copyAsync(&wg, src, &dst)
wg.Wait()
res := dst.String()
if res != s {
t.Errorf("%s != %s after copy", res, s)
}
}

func TestAuditLogging(t *testing.T) {
message := strconv.Itoa(rand.Int())
test := func(enabled bool, check func(*testing.T, *bytes.Buffer, *bytes.Buffer)) func(t *testing.T) {
Expand Down

0 comments on commit f5f2314

Please sign in to comment.