Skip to content

Commit

Permalink
perf: optimize reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
Sherlock-Holo committed Jan 23, 2020
1 parent d4b1516 commit eb8cdd7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 152 deletions.
117 changes: 0 additions & 117 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,117 +112,8 @@ func New(cfg *config.Config) (*Server, error) {
}

return server, nil

/*if server.webCertificateIsEnabled() {
// load web certificate
webCertificate, err := tls.LoadX509KeyPair(server.config.WebCrt, server.config.WebKey)
if err != nil {
log.Fatalf("%+v", xerrors.Errorf("read web key pair failed: %w", err))
}
tlsConfig.Certificates = append(tlsConfig.Certificates, webCertificate)
}
if server.reverseProxyIsEnabled() {
// load read reverse certificate
reverseProxyCertificate, err := tls.LoadX509KeyPair(server.config.ReverseProxyCrt, server.config.ReverseProxyKey)
if err != nil {
log.Fatalf("%+v", xerrors.Errorf("read reverse proxy key pair failed: %w", err))
}
tlsConfig.Certificates = append(tlsConfig.Certificates, reverseProxyCertificate)
}
tlsConfig.BuildNameToCertificate()
server.tlsListener, err = tls.Listen("tcp", server.config.ListenAddr, tlsConfig)
if err != nil {
log.Fatalf("%+v", xerrors.Errorf("listen %s failed: %w", server.config.ListenAddr, err))
}
mux := http.NewServeMux()
if server.webCertificateIsEnabled() {
mux.HandleFunc(server.config.Host+"/", server.proxyHandle)
mux.Handle(server.config.WebHost+"/", enableGzip(http.HandlerFunc(server.webHandle)))
} else {
mux.HandleFunc(server.config.Host+"/", server.checkRequest)
}
// enable reverse proxy
if server.reverseProxyIsEnabled() {
if !strings.HasPrefix(server.config.ReverseProxyAddr, "http") && !strings.HasPrefix(server.config.ReverseProxyAddr, "https") {
server.config.ReverseProxyAddr = "http://" + server.config.ReverseProxyAddr
}
u, err := url.Parse(server.config.ReverseProxyAddr)
if err != nil {
log.Fatalf("%+v", xerrors.Errorf("parse reverse proxy addr failed: %w", err))
}
proxy := httputil.NewSingleHostReverseProxy(u)
originDirector := proxy.Director
proxy.Director = func(r *http.Request) {
originDirector(r)
// delete origin field to avoid websocket upgrade check failed
r.Header.Del("origin")
}
mux.Handle(server.config.ReverseProxyHost+"/", proxy)
}
server.httpServer = http.Server{Handler: mux}
return*/
}

/*func (s *Server) checkRequest(w http.ResponseWriter, r *http.Request) {
code := r.Header.Get("totp-code")
ok, err := utils.VerifyCode(code, s.config.Secret, s.config.Period)
if err != nil {
http.Error(w, "server internal error", http.StatusInternalServerError)
log.Warnf("%+v", xerrors.Errorf("verify code error: %w", err))
return
}
if !ok || !websocket.IsWebSocketUpgrade(r) {
s.webHandle(w, r)
return
}
s.proxyHandle(w, r)
}

func (s *Server) webHandle(w http.ResponseWriter, r *http.Request) {
if s.config.WebRoot == "" {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
http.FileServer(http.Dir(s.config.WebRoot)).ServeHTTP(w, r)
}
func (s *Server) proxyHandle(w http.ResponseWriter, r *http.Request) {
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Warnf("%+v", xerrors.Errorf("websocket upgrade failed: %w", err))
return
}
linkCfg := link.DefaultConfig(link.ClientMode)
linkCfg.KeepaliveInterval = 5 * time.Second
manager := link.NewManager(wsWrapper.NewWrapper(conn), linkCfg)
for {
l, err := manager.Accept()
if err != nil {
log.Errorf("manager accept failed: %v", err)
manager.Close()
return
}
go handle(l)
}
}*/

func handle(conn net.Conn) {
address, err := libsocks.UnmarshalAddressFrom(conn)
if err != nil {
Expand Down Expand Up @@ -274,11 +165,3 @@ func (s *Server) Run() {
func (s *Server) Close() error {
return s.session.Close()
}

/*func (s *Server) webCertificateIsEnabled() bool {
return s.config.WebCrt != "" && s.config.WebKey != "" && s.config.WebRoot != "" && s.config.WebHost != ""
}
func (s *Server) reverseProxyIsEnabled() bool {
return s.config.ReverseProxyHost != "" && s.config.ReverseProxyCrt != "" && s.config.ReverseProxyKey != "" && s.config.ReverseProxyAddr != ""
}*/
75 changes: 41 additions & 34 deletions session/quic/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,33 +92,31 @@ func (q *quicClient) OpenConn(ctx context.Context) (net.Conn, error) {
q.connectMutex.Lock()
defer q.connectMutex.Unlock()

if q.quicSession == nil {
for {
log.Debug("start quic connect")
for q.quicSession == nil {
log.Debug("start quic connect")

err := q.reconnect(ctx)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
netErr := &net.OpError{
Op: "open",
Net: q.Name(),
Err: err,
}

return nil, errors.Errorf("connect quic failed: %w", netErr)
err := q.reconnect(ctx)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
netErr := &net.OpError{
Op: "open",
Net: q.Name(),
Err: err,
}

default:
err = errors.Errorf("quic connect failed: %w", err)
log.Warnf("%+v", err)
return nil, errors.Errorf("connect quic failed: %w", netErr)

continue
default:
err = errors.Errorf("quic connect failed: %w", err)
log.Warnf("%+v", err)

case err == nil:
}
continue

log.Debug("quic connect success")
break
case err == nil:
}

log.Debug("quic connect success")
break
}

select {
Expand Down Expand Up @@ -171,9 +169,16 @@ func (q *quicClient) reconnect(ctx context.Context) error {
_ = q.quicSession.Close()
}

var codeBuf bytes.Buffer
q.quicSession = nil

var (
codeBuf bytes.Buffer
errRet error
)

for i := 0; i < 2; i++ {
errRet = nil

session, err := quic.DialAddrContext(ctx, q.addr, q.tlsConfig, &quic.Config{
KeepAlive: true,
MaxIncomingStreams: math.MaxInt32,
Expand All @@ -182,7 +187,8 @@ func (q *quicClient) reconnect(ctx context.Context) error {
})

if err != nil {
return errors.Errorf("dial quic failed: %w", err)
errRet = errors.Errorf("dial quic failed: %w", err)
continue
}

code, err := utils.GenCode(q.secret, q.period)
Expand All @@ -200,14 +206,18 @@ func (q *quicClient) reconnect(ctx context.Context) error {

stream, err := session.OpenStreamSync(ctx)
if err != nil {
return errors.Errorf("open handshake stream failed: %w", err)
_ = session.Close()
errRet = errors.Errorf("open handshake stream failed: %w", err)

continue
}

if _, err := stream.Write(codeBytes); err != nil {
_ = stream.Close()
_ = session.Close()

return errors.Errorf("send TOTP code failed: %w", err)
errRet = errors.Errorf("send TOTP code failed: %w", err)
continue
}

log.Debug("write handshake success")
Expand All @@ -216,30 +226,27 @@ func (q *quicClient) reconnect(ctx context.Context) error {
_ = stream.Close()
_ = session.Close()

return errors.Errorf("set read deadline failed: %w", err)
errRet = errors.Errorf("set read deadline failed: %w", err)
continue
}

handshakeResp := make([]byte, 1)
if _, err := stream.Read(handshakeResp); err != nil {
_ = stream.Close()
_ = session.Close()

return errors.Errorf("get TOTP handshake response failed: %w", err)
errRet = errors.Errorf("get TOTP handshake response failed: %w", err)
continue
}

_ = stream.Close()

switch handshakeResp[0] {
case quicSession.HandshakeFailed:
_ = session.Close()

errRet = errors.New("connect failed: maybe TOTP secret is wrong")
log.Debug("handshake failed")

if i == 1 {

return errors.New("connect failed: maybe TOTP secret is wrong")
}

continue

case quicSession.HandshakeSuccess:
Expand All @@ -251,5 +258,5 @@ func (q *quicClient) reconnect(ctx context.Context) error {
break
}

return nil
return errRet
}
2 changes: 1 addition & 1 deletion session/quic/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func (q *quicServer) sessionHandshake(ctx context.Context, session quic.Session)
return false, errors.Errorf("read handshake message length failed: %w", err)
}

length := int(buf[0])
length := buf[0]

log.Debugf("handshake length %d", length)

Expand Down

0 comments on commit eb8cdd7

Please sign in to comment.