Skip to content

Commit

Permalink
feat: support feeling client disconnetion (#1054)
Browse files Browse the repository at this point in the history
Co-authored-by: Ricky-chen1 <Ricky-chen1@[email protected]>
Co-authored-by: kinggo <[email protected]>
  • Loading branch information
3 people authored May 28, 2024
1 parent 413ba29 commit 0fe1182
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 25 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7
github.com/bytedance/mockey v1.2.1
github.com/bytedance/sonic v1.8.1
github.com/cloudwego/netpoll v0.5.0
github.com/cloudwego/netpoll v0.6.0
github.com/fsnotify/fsnotify v1.5.4
github.com/tidwall/gjson v1.14.4
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU=
github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.6.0 h1:JRMkrA1o8k/4quxzg6Q1XM+zIhwZsyoWlq6ef+ht31U=
github.com/cloudwego/netpoll v0.6.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
55 changes: 55 additions & 0 deletions pkg/app/server/hertz_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
c "github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -134,3 +135,57 @@ func TestHertz_Spin(t *testing.T) {

<-ch2
}

func TestWithSenseClientDisconnection(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6631")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}

func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6632"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
return ctx
}))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6632")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}
16 changes: 16 additions & 0 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,19 @@ func WithDisableDefaultContentType(disable bool) config.Option {
o.NoDefaultContentType = disable
}}
}

// WithSenseClientDisconnection sets the ability to sense client disconnections.
// If we don't set it, it will default to false.
// There are two issues to note when using this option:
// 1. Warning: It only applies to netpoll.
// 2. After opening, the context.Context in the request will be cancelled.
//
// Example:
// server.Default(
// server.WithSenseClientDisconnection(true),
// )
func WithSenseClientDisconnection(b bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.SenseClientDisconnection = b
}}
}
3 changes: 3 additions & 0 deletions pkg/app/server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestOptions(t *testing.T) {
WithBasePath("/"),
WithMaxRequestBodySize(2),
WithDisablePrintRoute(true),
WithSenseClientDisconnection(true),
WithNetwork("unix"),
WithExitWaitTime(time.Second),
WithMaxKeepBodySize(500),
Expand Down Expand Up @@ -93,6 +94,7 @@ func TestOptions(t *testing.T) {
assert.DeepEqual(t, opt.BasePath, "/")
assert.DeepEqual(t, opt.MaxRequestBodySize, 2)
assert.DeepEqual(t, opt.DisablePrintRoute, true)
assert.DeepEqual(t, opt.SenseClientDisconnection, true)
assert.DeepEqual(t, opt.Network, "unix")
assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second)
assert.DeepEqual(t, opt.MaxKeepBodySize, 500)
Expand Down Expand Up @@ -130,6 +132,7 @@ func TestDefaultOptions(t *testing.T) {
assert.DeepEqual(t, opt.GetOnly, false)
assert.DeepEqual(t, opt.DisableKeepalive, false)
assert.DeepEqual(t, opt.DisablePrintRoute, false)
assert.DeepEqual(t, opt.SenseClientDisconnection, false)
assert.DeepEqual(t, opt.Network, "tcp")
assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5)
assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024)
Expand Down
4 changes: 4 additions & 0 deletions pkg/common/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type Options struct {
StreamRequestBody bool
NoDefaultServerHeader bool
DisablePrintRoute bool
SenseClientDisconnection bool
Network string
Addr string
BasePath string
Expand Down Expand Up @@ -203,6 +204,9 @@ func NewOptions(opts []Option) *Options {
// Disabled when set to True
DisablePrintRoute: false,

// The ability to sense client disconnection is disabled by default
SenseClientDisconnection: false,

// "tcp", "udp", "unix"(unix domain socket)
Network: defaultNetwork,

Expand Down
1 change: 1 addition & 0 deletions pkg/common/config/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func TestDefaultOptions(t *testing.T) {
assert.False(t, options.RemoveExtraSlash)
assert.True(t, options.UnescapePathValues)
assert.False(t, options.DisablePreParseMultipartForm)
assert.False(t, options.SenseClientDisconnection)
assert.DeepEqual(t, defaultNetwork, options.Network)
assert.DeepEqual(t, defaultAddr, options.Addr)
assert.DeepEqual(t, defaultMaxRequestBodySize, options.MaxRequestBodySize)
Expand Down
69 changes: 47 additions & 22 deletions pkg/network/netpoll/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,45 @@ func init() {
netpoll.SetLoggerOutput(io.Discard)
}

type ctxCancelKeyStruct struct{}

var ctxCancelKey = ctxCancelKeyStruct{}

func cancelContext(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
ctx = context.WithValue(ctx, ctxCancelKey, cancel)
return ctx
}

type transporter struct {
sync.RWMutex
network string
addr string
keepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
listener net.Listener
eventLoop netpoll.EventLoop
listenConfig *net.ListenConfig
OnAccept func(conn net.Conn) context.Context
OnConnect func(ctx context.Context, conn network.Conn) context.Context
senseClientDisconnection bool
network string
addr string
keepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
listener net.Listener
eventLoop netpoll.EventLoop
listenConfig *net.ListenConfig
OnAccept func(conn net.Conn) context.Context
OnConnect func(ctx context.Context, conn network.Conn) context.Context
}

// For transporter switch
func NewTransporter(options *config.Options) network.Transporter {
return &transporter{
network: options.Network,
addr: options.Addr,
keepAliveTimeout: options.KeepAliveTimeout,
readTimeout: options.ReadTimeout,
writeTimeout: options.WriteTimeout,
listener: nil,
eventLoop: nil,
listenConfig: options.ListenConfig,
OnAccept: options.OnAccept,
OnConnect: options.OnConnect,
senseClientDisconnection: options.SenseClientDisconnection,
network: options.Network,
addr: options.Addr,
keepAliveTimeout: options.KeepAliveTimeout,
readTimeout: options.ReadTimeout,
writeTimeout: options.WriteTimeout,
listener: nil,
eventLoop: nil,
listenConfig: options.ListenConfig,
OnAccept: options.OnAccept,
OnConnect: options.OnConnect,
}
}

Expand All @@ -88,10 +100,14 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
if t.writeTimeout > 0 {
conn.SetWriteTimeout(t.writeTimeout)
}
ctx := context.Background()
if t.OnAccept != nil {
return t.OnAccept(newConn(conn))
ctx = t.OnAccept(newConn(conn))
}
if t.senseClientDisconnection {
ctx = cancelContext(ctx)
}
return context.Background()
return ctx
}),
}

Expand All @@ -101,6 +117,15 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
}))
}

if t.senseClientDisconnection {
opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) {
cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc)
if cancelFunc != nil && ok {
cancelFunc()
}
}))
}

// Create EventLoop
t.Lock()
t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
Expand Down
29 changes: 29 additions & 0 deletions pkg/network/netpoll/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,35 @@ func TestTransport(t *testing.T) {
assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1)
})

t.Run("TestSenseClientDisconnection", func(t *testing.T) {
var onReqFlag int32
transporter := NewTransporter(&config.Options{
Addr: addr,
Network: nw,
SenseClientDisconnection: true,
})

go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error {
atomic.StoreInt32(&onReqFlag, 1)
time.Sleep(100 * time.Millisecond)
assert.DeepEqual(t, context.Canceled, ctx.Err())
return nil
})
defer transporter.Close()
time.Sleep(100 * time.Millisecond)

dial := NewDialer()
conn, err := dial.DialConnection(nw, addr, time.Second, nil)
assert.Nil(t, err)
_, err = conn.Write([]byte("123"))
assert.Nil(t, err)
err = conn.Close()
assert.Nil(t, err)
time.Sleep(100 * time.Millisecond)

assert.Assert(t, atomic.LoadInt32(&onReqFlag) == 1)
})

t.Run("TestListenConfig", func(t *testing.T) {
listenCfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
Expand Down

0 comments on commit 0fe1182

Please sign in to comment.