diff --git a/example/backend-service/echo/echo.go b/example/backend-service/echo/echo.go index 2e31e22db6..7321327179 100644 --- a/example/backend-service/echo/echo.go +++ b/example/backend-service/echo/echo.go @@ -1,14 +1,17 @@ package main import ( + "bufio" "fmt" "io" + "log" + "net" "net/http" "os" "time" ) -// TeeWriter is an io.Writer wapper. +// TeeWriter is an io.Writer wrapper. type TeeWriter struct { writers []io.Writer } @@ -26,7 +29,7 @@ func (tw *TeeWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func main() { +func httpServer() { echoHandler := func(w http.ResponseWriter, req *http.Request) { time.Sleep(10 * time.Millisecond) body, err := io.ReadAll(req.Body) @@ -60,3 +63,79 @@ func main() { http.ListenAndServe(":9095", nil) fmt.Println("listen and serve failed") } + +func tcpServer() { + echoHandler := func(conn net.Conn) { + time.Sleep(10 * time.Millisecond) + reader := bufio.NewReader(conn) + for { + message, err := reader.ReadString('\n') + if err != nil { + conn.Close() + return + } + fmt.Println("Message incoming: ", string(message)) + responseMsg := []byte( + "\nYour Message \n" + + "============== \n" + + "Message incoming: " + string(message) + "\n", + ) + conn.Write(responseMsg) + } + } + + listener, err := net.Listen("tcp", "127.0.0.1:9095") + if err != nil { + log.Fatal(err) + } + defer listener.Close() + + for { + conn, err := listener.Accept() + if err != nil { + log.Fatal(err) + } + go echoHandler(conn) + } +} + +func udpServer() { + echoHandler := func(pc net.PacketConn, addr net.Addr, buf []byte) { + time.Sleep(10 * time.Millisecond) + + fmt.Println("Your Message") + fmt.Println("==============") + fmt.Printf("Message incoming: %s \n", string(buf)) + + pc.WriteTo(buf, addr) + } + pc, err := net.ListenPacket("udp", ":9095") + if err != nil { + log.Fatal(err) + } + defer pc.Close() + + for { + buf := make([]byte, 1024) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + continue + } + go echoHandler(pc, addr, buf[:n]) + } +} + +func main() { + protocol := "http" + if len(os.Args) > 1 { + protocol = os.Args[1] + } + switch protocol { + case "tcp": + tcpServer() + case "udp": + udpServer() + default: + httpServer() + } +} diff --git a/example/client/tcp_udp.go b/example/client/tcp_udp.go new file mode 100644 index 0000000000..c751ec2d14 --- /dev/null +++ b/example/client/tcp_udp.go @@ -0,0 +1,83 @@ +package main + +import ( + "bufio" + "fmt" + "net" + "os" +) + +func tcpClient() { + strEcho := "Hello from client! \n" + servAddr := "127.0.0.1:10080" + if len(os.Args) > 2 { + servAddr = os.Args[2] + } + tcpAddr, err := net.ResolveTCPAddr("tcp", servAddr) + if err != nil { + fmt.Println("ResolveTCPAddr failed:", err.Error()) + os.Exit(1) + } + + conn, err := net.DialTCP("tcp", nil, tcpAddr) + if err != nil { + fmt.Println("Dial failed:", err.Error()) + os.Exit(1) + } + + _, err = conn.Write([]byte(strEcho)) + if err != nil { + fmt.Println("Write to server failed:", err.Error()) + os.Exit(1) + } + + fmt.Println("write to server = ", strEcho) + + reply := make([]byte, 1024) + + _, err = conn.Read(reply) + if err != nil { + fmt.Println("Write to server failed:", err.Error()) + os.Exit(1) + } + + fmt.Println("reply from server=", string(reply)) + + _ = conn.Close() +} + +func udpClient() { + p := make([]byte, 2048) + servAddr := "127.0.0.1:10070" + if len(os.Args) > 2 { + servAddr = os.Args[2] + } + conn, err := net.Dial("udp", servAddr) + if err != nil { + fmt.Printf("Some error %v", err) + return + } + _, _ = fmt.Fprintf(conn, "Ping from client") + _, err = bufio.NewReader(conn).Read(p) + if err == nil { + fmt.Printf("%s\n", p) + } else { + fmt.Printf("Some error %v\n", err) + } + _ = conn.Close() +} + +func main() { + protocol := "" + if len(os.Args) > 1 { + protocol = os.Args[1] + } + switch protocol { + case "tcp": + tcpClient() + case "udp": + udpClient() + default: + fmt.Println("Please provide udp or tcp flag.") + } +} diff --git a/go.mod b/go.mod index 795a21bdfd..621830807c 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( github.com/tcnksm/go-httpstat v0.2.1-0.20191008022543-e866bb274419 github.com/tidwall/gjson v1.11.0 github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce + github.com/valyala/bytebufferpool v1.0.0 github.com/valyala/fasttemplate v1.2.1 github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonschema v1.2.1-0.20201027075954-b076d39a02e5 diff --git a/pkg/context/httpcontext.go b/pkg/context/httpcontext.go index 35ad9b0192..da580187e5 100644 --- a/pkg/context/httpcontext.go +++ b/pkg/context/httpcontext.go @@ -80,7 +80,7 @@ type ( Method() string SetMethod(method string) - // URL + // Scheme URL Scheme() string Host() string SetHost(host string) diff --git a/pkg/object/httpserver/httpserver.go b/pkg/object/httpserver/httpserver.go index 5ec34619de..8637a42f1e 100644 --- a/pkg/object/httpserver/httpserver.go +++ b/pkg/object/httpserver/httpserver.go @@ -81,7 +81,7 @@ func (hs *HTTPServer) Inherit(superSpec *supervisor.Spec, previousGeneration sup } } -// Status is the wrapper of runtime's Status. +// Status is the wrapper of runtimes Status. func (hs *HTTPServer) Status() *supervisor.Status { return &supervisor.Status{ ObjectStatus: hs.runtime.Status(), diff --git a/pkg/object/tcpproxy/connection.go b/pkg/object/tcpproxy/connection.go new file mode 100644 index 0000000000..1d8f45fc7d --- /dev/null +++ b/pkg/object/tcpproxy/connection.go @@ -0,0 +1,400 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tcpproxy + +import ( + "io" + "net" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/util/fasttime" + "github.com/megaease/easegress/pkg/util/iobufferpool" + "github.com/megaease/easegress/pkg/util/timerpool" +) + +const writeBufSize = 8 + +var tcpBufferPool = sync.Pool{ + New: func() interface{} { + buf := make([]byte, iobufferpool.DefaultBufferReadCapacity) + return buf + }, +} + +// Connection wrap tcp connection +type Connection struct { + closed uint32 + rawConn net.Conn + localAddr net.Addr + remoteAddr net.Addr + + readBuffer []byte + writeBuffers net.Buffers + ioBuffers []*iobufferpool.StreamBuffer + writeBufferChan chan *iobufferpool.StreamBuffer + + mu sync.Mutex + readLoopExit chan struct{} + writeLoopExit chan struct{} + connStopChan chan struct{} // notify write loop doesn't block on read buffer channel + listenerStopChan chan struct{} // notify tcp listener has been closed, just use in read loop + + lastReadDeadlineTime time.Time + lastWriteDeadlineTime time.Time + + onRead func(buffer *iobufferpool.StreamBuffer) // execute read filters + onClose func(event ConnectionEvent) +} + +// NewClientConn wrap connection create from client +func NewClientConn(conn net.Conn, listenerStopChan chan struct{}) *Connection { + return &Connection{ + rawConn: conn, + localAddr: conn.LocalAddr(), + remoteAddr: conn.RemoteAddr(), + readLoopExit: make(chan struct{}), + writeLoopExit: make(chan struct{}), + connStopChan: make(chan struct{}), + listenerStopChan: listenerStopChan, + + mu: sync.Mutex{}, + writeBufferChan: make(chan *iobufferpool.StreamBuffer, writeBufSize), + } +} + +// SetOnRead set connection read handle +func (c *Connection) SetOnRead(onRead func(buffer *iobufferpool.StreamBuffer)) { + c.onRead = onRead +} + +// SetOnClose set close callback +func (c *Connection) SetOnClose(onclose func(event ConnectionEvent)) { + c.onClose = onclose +} + +// Start running connection read/write loop +func (c *Connection) Start() { + fnRecover := func() { + if r := recover(); r != nil { + logger.Errorf("tcp read/write loop panic: %v\n%s\n", r, string(debug.Stack())) + c.Close(NoFlush, LocalClose) + } + } + + go func() { + defer fnRecover() + c.startReadLoop() + }() + + go func() { + defer fnRecover() + c.startWriteLoop() + }() +} + +// Write receive other connection data +func (c *Connection) Write(buf *iobufferpool.StreamBuffer) (err error) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("tcp connection has closed, local addr: %s, remote addr: %s, err: %+v", + c.localAddr.String(), c.remoteAddr.String(), r) + err = ErrConnectionHasClosed + } + }() + + select { + case c.writeBufferChan <- buf: + return + default: + } + + // try to send data again in 60 seconds + t := timerpool.Get(60 * time.Second) + select { + case c.writeBufferChan <- buf: + case <-t.C: + buf.Release() + err = ErrWriteBufferChanTimeout + } + timerpool.Put(t) + return +} + +func (c *Connection) startReadLoop() { + defer func() { + close(c.readLoopExit) + if c.readBuffer != nil { + tcpBufferPool.Put(c.readBuffer[:iobufferpool.DefaultBufferReadCapacity]) + } + }() + + var n int + var err error + for { + if atomic.LoadUint32(&c.closed) == 1 { + logger.Debugf("connection has been closed, exit read loop, local addr: %s, remote addr: %s", + c.localAddr.String(), c.remoteAddr.String()) + return + } + + select { + case <-c.listenerStopChan: + logger.Debugf("listener stopped, exit read loop,local addr: %s, remote addr: %s", + c.localAddr.String(), c.remoteAddr.String()) + c.Close(NoFlush, LocalClose) + return + default: + } + + if n, err = c.doReadIO(); n > 0 { + c.onRead(iobufferpool.NewStreamBuffer(c.readBuffer[:n])) + } + + if err == nil { + continue + } + if te, ok := err.(net.Error); ok && te.Timeout() { + continue // c.closed will be check in the front of read loop + } + + if err == io.EOF { + logger.Debugf("remote close connection, local addr: %s, remote addr: %s, err: %s", + c.localAddr.String(), c.remoteAddr.String(), err.Error()) + c.Close(NoFlush, RemoteClose) + } else { + logger.Errorf("error on read, local addr: %s, remote addr: %s, err: %s", + c.localAddr.String(), c.remoteAddr.String(), err.Error()) + c.Close(NoFlush, OnReadErrClose) + } + return + } +} + +func (c *Connection) startWriteLoop() { + + defer func() { + close(c.writeLoopExit) + }() + + for { + if atomic.LoadUint32(&c.closed) == 1 { + logger.Debugf("connection has been closed, exit write loop, local addr: %s, remote addr: %s", + c.localAddr.String(), c.remoteAddr.String()) + return + } + + select { + case <-c.connStopChan: + logger.Debugf("connection has been closed, exit write loop, local addr: %s, remote addr: %s", + c.localAddr.String(), c.remoteAddr.String()) + return + case buf, ok := <-c.writeBufferChan: + if !ok { + return + } + c.appendBuffer(buf) + NoMoreData: + // Keep reading until writeBufferChan is empty + // writeBufferChan may be full when writeLoop call doWrite + for i := 0; i < writeBufSize-1; i++ { + select { + case buf, ok := <-c.writeBufferChan: + if !ok { + return + } + c.appendBuffer(buf) + default: + break NoMoreData + } + } + } + + _, err := c.doWrite() + if err == nil { + continue + } + if te, ok := err.(net.Error); ok && te.Timeout() { + continue + } + + if err == io.EOF { + logger.Debugf("finish write with eof, local addr: %s, remote addr: %s", + c.localAddr.String(), c.remoteAddr.String()) + c.Close(NoFlush, LocalClose) + } else { + // remote call CloseRead, so just exit write loop, wait read loop exit + logger.Errorf("error on write, local addr: %s, remote addr: %s, err: %+v", + c.localAddr.String(), c.remoteAddr.String(), err) + } + return + } +} + +func (c *Connection) appendBuffer(buf *iobufferpool.StreamBuffer) { + if buf == nil { + return + } + c.ioBuffers = append(c.ioBuffers, buf) + c.writeBuffers = append(c.writeBuffers, buf.Bytes()) +} + +// Close connection close function +func (c *Connection) Close(ccType CloseType, event ConnectionEvent) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("connection close panic, err: %+v\n%s", r, string(debug.Stack())) + } + }() + + if ccType == FlushWrite { + _ = c.Write(iobufferpool.NewEOFStreamBuffer()) // wait for write loop to call close function again + return + } + + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + // connection has already closed, so there is no need to execute below code + return + } + logger.Infof("enter connection close func(%s), local addr: %s, remote addr: %s", + event, c.localAddr.String(), c.remoteAddr.String()) + + _ = c.rawConn.SetDeadline(time.Now()) // notify read/write loop to exit + close(c.connStopChan) + c.onClose(event) + + go func() { + <-c.readLoopExit + <-c.writeLoopExit + // wait for read/write loop exit, then close socket(avoid exceptions caused by close operations) + _ = c.rawConn.Close() + }() +} + +func (c *Connection) doReadIO() (bufLen int, err error) { + if c.readBuffer == nil { + c.readBuffer = tcpBufferPool.Get().([]byte) + } + + // add read deadline setting optimization? + // https://github.com/golang/go/issues/15133 + curr := fasttime.Now().Add(15 * time.Second) + // there is no need to set readDeadline in too short time duration + if diff := curr.Sub(c.lastReadDeadlineTime).Milliseconds(); diff > 0 { + _ = c.rawConn.SetReadDeadline(curr) + c.lastReadDeadlineTime = curr + } + return c.rawConn.(io.Reader).Read(c.readBuffer) +} + +func (c *Connection) doWrite() (int64, error) { + curr := fasttime.Now().Add(15 * time.Second) + // there is no need to set writeDeadline in too short time duration + if diff := curr.Sub(c.lastWriteDeadlineTime).Milliseconds(); diff > 0 { + _ = c.rawConn.SetWriteDeadline(curr) + c.lastWriteDeadlineTime = curr + } + return c.doWriteIO() +} + +func (c *Connection) writeBufLen() (bufLen int) { + for _, buf := range c.writeBuffers { + bufLen += len(buf) + } + return +} + +func (c *Connection) doWriteIO() (bytesSent int64, err error) { + buffers := c.writeBuffers + bytesSent, err = buffers.WriteTo(c.rawConn) + if err != nil { + return bytesSent, err + } + + for i, buf := range c.ioBuffers { + c.ioBuffers[i] = nil + c.writeBuffers[i] = nil + if buf.EOF() { + err = io.EOF + } + buf.Release() + } + c.ioBuffers = c.ioBuffers[:0] + c.writeBuffers = c.writeBuffers[:0] + return +} + +// ServerConnection wrap tcp connection to backend server +type ServerConnection struct { + Connection + connectTimeout time.Duration +} + +// NewServerConn construct tcp server connection +func NewServerConn(connectTimeout uint32, serverAddr net.Addr, listenerStopChan chan struct{}) *ServerConnection { + conn := &ServerConnection{ + Connection: Connection{ + remoteAddr: serverAddr, + readLoopExit: make(chan struct{}), + writeLoopExit: make(chan struct{}), + connStopChan: make(chan struct{}), + writeBufferChan: make(chan *iobufferpool.StreamBuffer, writeBufSize), + + mu: sync.Mutex{}, + listenerStopChan: listenerStopChan, + }, + connectTimeout: time.Duration(connectTimeout) * time.Millisecond, + } + return conn +} + +// Connect create backend server tcp connection +func (u *ServerConnection) Connect() bool { + addr := u.remoteAddr + if addr == nil { + logger.Errorf("cannot connect because the server has been closed, server addr: %s", addr.String()) + return false + } + + timeout := u.connectTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + + var err error + u.rawConn, err = net.DialTimeout("tcp", addr.String(), timeout) + if err != nil { + if err == io.EOF { + logger.Errorf("cannot connect because the server has been closed, server addr: %s", addr.String()) + } else if te, ok := err.(net.Error); ok && te.Timeout() { + logger.Errorf("connect to server timeout, server addr: %s", addr.String()) + } else { + logger.Errorf("connect to server failed, server addr: %s, err: %s", addr.String(), err.Error()) + } + return false + } + + u.localAddr = u.rawConn.LocalAddr() + _ = u.rawConn.(*net.TCPConn).SetNoDelay(true) + _ = u.rawConn.(*net.TCPConn).SetKeepAlive(true) + u.Start() + return true +} diff --git a/pkg/object/tcpproxy/constant.go b/pkg/object/tcpproxy/constant.go new file mode 100644 index 0000000000..07a6a2daaf --- /dev/null +++ b/pkg/object/tcpproxy/constant.go @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tcpproxy + +import ( + "errors" +) + +// CloseType represent connection close type +type CloseType string + +//Connection close types +const ( + // FlushWrite means write buffer to underlying io then close connection + FlushWrite CloseType = "FlushWrite" + // NoFlush means close connection without flushing buffer + NoFlush CloseType = "NoFlush" +) + +// ConnectionEvent type +type ConnectionEvent string + +const ( + // RemoteClose connection closed by remote + RemoteClose ConnectionEvent = "RemoteClose" + // LocalClose connection closed by local + LocalClose ConnectionEvent = "LocalClose" + // OnReadErrClose connection closed by read error + OnReadErrClose ConnectionEvent = "OnReadErrClose" + // Connected connection has been connected + Connected ConnectionEvent = "ConnectedFlag" + // ConnectTimeout connect to remote failed due to timeout + ConnectTimeout ConnectionEvent = "ConnectTimeout" + // ConnectFailed connect to remote failed + ConnectFailed ConnectionEvent = "ConnectFailed" + // OnWriteTimeout write data failed due to timeout + OnWriteTimeout ConnectionEvent = "OnWriteTimeout" +) + +var ( + // ErrConnectionHasClosed connection has been closed + ErrConnectionHasClosed = errors.New("connection has closed") + // ErrWriteBufferChanTimeout writeBufferChan has timeout + ErrWriteBufferChanTimeout = errors.New("writeBufferChan has timeout") +) + +// ConnState status +type ConnState int + +// Connection statuses +const ( + ConnInit ConnState = iota + ConnActive + ConnClosed +) diff --git a/pkg/object/tcpproxy/listener.go b/pkg/object/tcpproxy/listener.go new file mode 100644 index 0000000000..1ab5ac5664 --- /dev/null +++ b/pkg/object/tcpproxy/listener.go @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tcpproxy + +import ( + "fmt" + "net" + "sync" + + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/util/limitlistener" +) + +// ListenerState listener running state +type ListenerState int + +type listener struct { + name string + localAddr string // listen addr + state ListenerState // listener state + + mutex *sync.Mutex + stopChan chan struct{} + maxConns uint32 // maxConn for tcp listener + + listener *limitlistener.LimitListener // tcp listener with accept limit + onAccept func(conn net.Conn, listenerStop chan struct{}) // tcp accept handle +} + +func newListener(spec *Spec, onAccept func(conn net.Conn, listenerStop chan struct{})) *listener { + listen := &listener{ + name: spec.Name, + localAddr: fmt.Sprintf(":%d", spec.Port), + + onAccept: onAccept, + maxConns: spec.MaxConnections, + + mutex: &sync.Mutex{}, + stopChan: make(chan struct{}), + } + return listen +} + +func (l *listener) listen() error { + tl, err := net.Listen("tcp", l.localAddr) + if err != nil { + return err + } + // wrap tcp listener with accept limit + l.listener = limitlistener.NewLimitListener(tl, l.maxConns) + return nil +} + +func (l *listener) acceptEventLoop() { + + for { + tconn, err := l.listener.Accept() + if err == nil { + go l.onAccept(tconn, l.stopChan) + continue + } + + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + logger.Infof("tcp listener(%s) stop accept connection due to timeout, err: %s", + l.localAddr, nerr) + return + } + + ope, ok := err.(*net.OpError) + if !ok { + logger.Errorf("tcp listener(%s) stop accept connection with unknown error: %s.", + l.localAddr, err.Error()) + return + } + + // not timeout error and not temporary, which means the error is non-recoverable + if !(ope.Timeout() && ope.Temporary()) { + // accept error raised by sockets closing + if ope.Op == "accept" { + logger.Debugf("tcp listener(%s) closed, stop accept connection", l.localAddr) + } else { + logger.Errorf("tcp listener(%s) stop accept connection due to non-recoverable error: %s", + l.localAddr, err.Error()) + } + return + } + } +} + +func (l *listener) setMaxConnection(maxConn uint32) { + l.listener.SetMaxConnection(maxConn) +} + +func (l *listener) close() (err error) { + if l.listener != nil { + err = l.listener.Close() + } + close(l.stopChan) + return err +} diff --git a/pkg/object/tcpproxy/runtime.go b/pkg/object/tcpproxy/runtime.go new file mode 100644 index 0000000000..42b04c3670 --- /dev/null +++ b/pkg/object/tcpproxy/runtime.go @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tcpproxy + +import ( + "fmt" + "net" + "reflect" + "sync/atomic" + "time" + + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/supervisor" + "github.com/megaease/easegress/pkg/util/iobufferpool" + "github.com/megaease/easegress/pkg/util/ipfilter" + "github.com/megaease/easegress/pkg/util/layer4backend" +) + +const ( + checkFailedTimeout = 10 * time.Second + + stateNil stateType = "nil" + stateFailed stateType = "failed" + stateRunning stateType = "running" + stateClosed stateType = "closed" +) + +var ( + errNil = fmt.Errorf("") +) + +type ( + stateType string + + eventCheckFailed struct{} + eventServeFailed struct { + startNum uint64 + err error + } + + eventReload struct { + nextSuperSpec *supervisor.Spec + } + eventClose struct{ done chan struct{} } + + runtime struct { + superSpec *supervisor.Spec + spec *Spec + + pool *layer4backend.Pool // backend servers pool + ipFilters *ipfilter.Layer4IpFilters // ip filters + listener *listener // tcp listener + + startNum uint64 + eventChan chan interface{} // receive traffic controller event + + state atomic.Value // runtime running state + err atomic.Value // runtime running error + } +) + +func newRuntime(superSpec *supervisor.Spec) *runtime { + spec := superSpec.ObjectSpec().(*Spec) + r := &runtime{ + superSpec: superSpec, + + pool: layer4backend.NewPool(superSpec.Super(), spec.Pool, ""), + ipFilters: ipfilter.NewLayer4IPFilters(spec.IPFilter), + + eventChan: make(chan interface{}, 10), + } + + r.setState(stateNil) + r.setError(errNil) + + go r.fsm() + go r.checkFailed() + return r +} + +// Close notify runtime close +func (r *runtime) Close() { + done := make(chan struct{}) + r.eventChan <- &eventClose{done: done} + <-done +} + +// FSM is the finite-state-machine for the runtime. +func (r *runtime) fsm() { + for e := range r.eventChan { + switch e := e.(type) { + case *eventCheckFailed: + r.handleEventCheckFailed() + case *eventServeFailed: + r.handleEventServeFailed(e) + case *eventReload: + r.handleEventReload(e) + case *eventClose: + r.handleEventClose(e) + // NOTE: We don't close hs.eventChan, + // in case of panic of any other goroutines + // to send event to it later. + return + default: + logger.Errorf("BUG: unknown event: %T\n", e) + } + } +} + +func (r *runtime) reload(nextSuperSpec *supervisor.Spec) { + r.superSpec = nextSuperSpec + nextSpec := nextSuperSpec.ObjectSpec().(*Spec) + r.ipFilters.ReloadRules(nextSpec.IPFilter) + r.pool.ReloadRules(nextSuperSpec.Super(), nextSpec.Pool, "") + + // r.listener does not create just after the process started and the config load for the first time. + if nextSpec != nil && r.listener != nil { + r.listener.setMaxConnection(nextSpec.MaxConnections) + } + + // NOTE: Due to the mechanism of supervisor, + // nextSpec must not be nil, just defensive programming here. + switch { + case r.spec == nil && nextSpec == nil: + logger.Errorf("BUG: nextSpec is nil") + // Nothing to do. + case r.spec == nil && nextSpec != nil: + r.spec = nextSpec + r.startServer() + case r.spec != nil && nextSpec == nil: + logger.Errorf("BUG: nextSpec is nil") + r.spec = nil + r.closeServer() + case r.spec != nil && nextSpec != nil: + if r.needRestartServer(nextSpec) { + r.spec = nextSpec + r.closeServer() + r.startServer() + } else { + r.spec = nextSpec + } + } +} + +func (r *runtime) setState(state stateType) { + r.state.Store(state) +} + +func (r *runtime) getState() stateType { + return r.state.Load().(stateType) +} + +func (r *runtime) setError(err error) { + if err == nil { + r.err.Store(errNil) + } else { + // NOTE: For type safe. + r.err.Store(fmt.Errorf("%v", err)) + } +} + +func (r *runtime) getError() error { + err := r.err.Load() + if err == nil { + return nil + } + return err.(error) +} + +func (r *runtime) needRestartServer(nextSpec *Spec) bool { + x := *r.spec + y := *nextSpec + + // The change of options below need not restart the tcp server. + x.MaxConnections, y.MaxConnections = 0, 0 + x.ConnectTimeout, y.ConnectTimeout = 0, 0 + + x.Pool, y.Pool = nil, nil + x.IPFilter, y.IPFilter = nil, nil + + // The update of rules need not to shutdown server. + return !reflect.DeepEqual(x, y) +} + +func (r *runtime) startServer() { + l := newListener(r.spec, r.onAccept()) + + r.listener = l + r.startNum++ + r.setState(stateRunning) + r.setError(nil) + + if err := l.listen(); err != nil { + r.setState(stateFailed) + r.setError(err) + logger.Errorf("tcp listener for %s failed, err: %+v", l.localAddr, err) + + _ = l.close() + r.eventChan <- &eventServeFailed{ + err: err, + startNum: r.startNum, + } + return + } + + go r.listener.acceptEventLoop() +} + +func (r *runtime) closeServer() { + if r.listener == nil { + return + } + + _ = r.listener.close() + logger.Infof("listener for %s(%s) closed", r.listener.name, r.listener.localAddr) +} + +func (r *runtime) checkFailed() { + ticker := time.NewTicker(checkFailedTimeout) + for range ticker.C { + state := r.getState() + if state == stateFailed { + r.eventChan <- &eventCheckFailed{} + } else if state == stateClosed { + ticker.Stop() + return + } + } +} + +func (r *runtime) handleEventCheckFailed() { + if r.getState() == stateFailed { + r.startServer() + } +} + +func (r *runtime) handleEventServeFailed(e *eventServeFailed) { + if r.startNum > e.startNum { + return + } + r.setState(stateFailed) + r.setError(e.err) +} + +func (r *runtime) handleEventReload(e *eventReload) { + r.reload(e.nextSuperSpec) +} + +func (r *runtime) handleEventClose(e *eventClose) { + r.closeServer() + r.pool.Close() + close(e.done) +} + +func (r *runtime) onAccept() func(conn net.Conn, listenerStop chan struct{}) { + + return func(rawConn net.Conn, listenerStop chan struct{}) { + clientIP := rawConn.RemoteAddr().(*net.TCPAddr).IP.String() + if !r.ipFilters.AllowIP(clientIP) { + _ = rawConn.Close() + logger.Infof("close tcp connection from %s to %s which ip is not allowed", + rawConn.RemoteAddr().String(), rawConn.LocalAddr().String()) + return + } + + server, err := r.pool.Next(clientIP) + if err != nil { + _ = rawConn.Close() + logger.Errorf("close tcp connection due to no available server, local addr: %s, err: %+v", + rawConn.LocalAddr(), err) + return + } + + serverAddr, _ := net.ResolveTCPAddr("tcp", server.Addr) + serverConn := NewServerConn(r.spec.ConnectTimeout, serverAddr, listenerStop) + if !serverConn.Connect() { + _ = rawConn.Close() + return + } + + clientConn := NewClientConn(rawConn, listenerStop) + r.setCallbacks(clientConn, serverConn) + clientConn.Start() // server conn start read/write loop when connect is called + } +} + +func (r *runtime) setCallbacks(clientConn *Connection, serverConn *ServerConnection) { + clientConn.SetOnRead(func(readBuf *iobufferpool.StreamBuffer) { + if readBuf != nil && readBuf.Len() > 0 { + _ = serverConn.Write(readBuf) + } + }) + serverConn.SetOnRead(func(readBuf *iobufferpool.StreamBuffer) { + if readBuf != nil && readBuf.Len() > 0 { + _ = clientConn.Write(readBuf) + } + }) + + clientConn.SetOnClose(func(event ConnectionEvent) { + if event == RemoteClose { + serverConn.Close(FlushWrite, LocalClose) + } else { + serverConn.Close(NoFlush, LocalClose) + } + }) + serverConn.SetOnClose(func(event ConnectionEvent) { + if event == RemoteClose { + clientConn.Close(FlushWrite, LocalClose) + } else { + clientConn.Close(NoFlush, LocalClose) + } + }) +} diff --git a/pkg/object/tcpproxy/spec.go b/pkg/object/tcpproxy/spec.go new file mode 100644 index 0000000000..32c979d379 --- /dev/null +++ b/pkg/object/tcpproxy/spec.go @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tcpproxy + +import ( + "github.com/megaease/easegress/pkg/util/ipfilter" + "github.com/megaease/easegress/pkg/util/layer4backend" +) + +type ( + // Spec describes the Layer4 Server. + Spec struct { + Name string `yaml:"name" json:"name" jsonschema:"required"` + Port uint16 `yaml:"port" json:"port" jsonschema:"required"` + + // tcp stream config params + MaxConnections uint32 `yaml:"maxConns" jsonschema:"omitempty,minimum=1"` + ConnectTimeout uint32 `yaml:"connectTimeout" jsonschema:"omitempty"` + + Pool *layer4backend.Spec `yaml:"pool" jsonschema:"required"` + IPFilter *ipfilter.Spec `yaml:"ipFilters,omitempty" jsonschema:"omitempty"` + } +) + +// Validate validates Layer4 Server. +func (spec *Spec) Validate() error { + if poolErr := spec.Pool.Validate(); poolErr != nil { + return poolErr + } + + return nil +} diff --git a/pkg/object/tcpproxy/tcpserver.go b/pkg/object/tcpproxy/tcpserver.go new file mode 100644 index 0000000000..87787c91d0 --- /dev/null +++ b/pkg/object/tcpproxy/tcpserver.go @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tcpproxy + +import ( + "github.com/megaease/easegress/pkg/supervisor" +) + +const ( + // Category is the category of TCPServer. + Category = supervisor.CategoryBusinessController + + // Kind is the kind of TCPServer. + Kind = "TCPServer" +) + +func init() { + supervisor.Register(&TCPServer{}) +} + +type ( + // TCPServer is Object of tcp server. + TCPServer struct { + runtime *runtime + } +) + +// Category returns the category of TCPServer. +func (l4 *TCPServer) Category() supervisor.ObjectCategory { + return Category +} + +// Kind returns the kind of TCPServer. +func (l4 *TCPServer) Kind() string { + return Kind +} + +// DefaultSpec returns the default spec of TCPServer. +func (l4 *TCPServer) DefaultSpec() interface{} { + return &Spec{ + MaxConnections: 1024, + ConnectTimeout: 5 * 1000, + } +} + +// Validate validates the tcp server structure. +func (l4 *TCPServer) Validate() error { + return nil +} + +// Init initializes TCPServer. +func (l4 *TCPServer) Init(superSpec *supervisor.Spec) { + + l4.runtime = newRuntime(superSpec) + l4.runtime.eventChan <- &eventReload{ + nextSuperSpec: superSpec, + } +} + +// Inherit inherits previous generation of TCPServer. +func (l4 *TCPServer) Inherit(superSpec *supervisor.Spec, previousGeneration supervisor.Object) { + + l4.runtime = previousGeneration.(*TCPServer).runtime + l4.runtime.eventChan <- &eventReload{ + nextSuperSpec: superSpec, + } +} + +// Status is the wrapper of runtimes Status. +func (l4 *TCPServer) Status() *supervisor.Status { + return &supervisor.Status{} +} + +// Close actually close tcp server runtime +func (l4 *TCPServer) Close() { + l4.runtime.Close() +} diff --git a/pkg/object/udpproxy/runtime.go b/pkg/object/udpproxy/runtime.go new file mode 100644 index 0000000000..d1a034ce92 --- /dev/null +++ b/pkg/object/udpproxy/runtime.go @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package udpproxy + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/supervisor" + "github.com/megaease/easegress/pkg/util/iobufferpool" + "github.com/megaease/easegress/pkg/util/ipfilter" + "github.com/megaease/easegress/pkg/util/layer4backend" +) + +type ( + runtime struct { + superSpec *supervisor.Spec + spec *Spec + + pool *layer4backend.Pool // backend servers pool + serverConn *net.UDPConn // listener + sessions map[string]*session + + ipFilters *ipfilter.Layer4IpFilters + + mu sync.Mutex + done chan struct{} + } +) + +func newRuntime(superSpec *supervisor.Spec) *runtime { + spec := superSpec.ObjectSpec().(*Spec) + r := &runtime{ + superSpec: superSpec, + spec: spec, + + pool: layer4backend.NewPool(superSpec.Super(), spec.Pool, ""), + ipFilters: ipfilter.NewLayer4IPFilters(spec.IPFilter), + + done: make(chan struct{}), + sessions: make(map[string]*session), + } + + r.startServer() + return r +} + +// close notify runtime close +func (r *runtime) close() { + close(r.done) + _ = r.serverConn.Close() + r.pool.Close() +} + +func (r *runtime) startServer() { + listenAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", r.spec.Port)) + if err != nil { + logger.Errorf("parse udp listen addr(%s) failed, err: %+v", r.spec.Port, err) + return + } + + r.serverConn, err = net.ListenUDP("udp", listenAddr) + if err != nil { + logger.Errorf("create udp listener(%s) failed, err: %+v", r.spec.Port, err) + return + } + + var cp *connPool + if !r.spec.HasResponse { + // if client udp request doesn't have response, use connection pool to save server connections pool + cp = newConnPool() + } + + go func() { + defer cp.close() + + buf := make([]byte, iobufferpool.UDPPacketMaxSize) + for { + n, clientAddr, err := r.serverConn.ReadFromUDP(buf[:]) + + if err != nil { + select { + case <-r.done: + return // detect weather udp server is closed + default: + } + + if ope, ok := err.(*net.OpError); ok { + // not timeout error and not temporary, which means the error is non-recoverable + if !(ope.Timeout() && ope.Temporary()) { + logger.Errorf("udp listener(%d) crashed due to non-recoverable error, err: %+v", r.spec.Port, err) + return + } + } + logger.Errorf("failed to read packet from udp connection(:%d), err: %+v", r.spec.Port, err) + continue + } + + if !r.ipFilters.AllowIP(clientAddr.IP.String()) { + logger.Debugf("discard udp packet from %s send to udp server(:%d)", clientAddr.IP.String(), r.spec.Port) + continue + } + + if !r.spec.HasResponse { + if err := r.sendOneShot(cp, clientAddr, buf[0:n]); err != nil { + logger.Errorf("%s", err.Error()) + } + continue + } + + r.proxy(clientAddr, buf[0:n]) + } + }() +} + +func (r *runtime) getServerConn(pool *connPool, clientAddr *net.UDPAddr) (net.Conn, string, error) { + server, err := r.pool.Next(clientAddr.IP.String()) + if err != nil { + return nil, "", fmt.Errorf("can not get server addr for udp connection(:%d)", r.spec.Port) + } + + var serverConn net.Conn + if pool != nil { + serverConn = pool.get(server.Addr) + if serverConn != nil { + return serverConn, server.Addr, nil + } + } + + addr, err := net.ResolveUDPAddr("udp", server.Addr) + if err != nil { + return nil, server.Addr, fmt.Errorf("parse server addr(%s) to udp addr failed, err: %+v", server.Addr, err) + } + + serverConn, err = net.DialUDP("udp", nil, addr) + if err != nil { + return nil, server.Addr, fmt.Errorf("dial to server addr(%s) failed, err: %+v", server.Addr, err) + } + + if pool != nil { + pool.put(server.Addr, serverConn) + } + return serverConn, server.Addr, nil +} + +func (r *runtime) sendOneShot(pool *connPool, clientAddr *net.UDPAddr, buf []byte) error { + serverConn, serverAddr, err := r.getServerConn(pool, clientAddr) + if err != nil { + return err + } + + n, err := serverConn.Write(buf) + if err != nil { + return fmt.Errorf("sned data to %s failed, err: %+v", serverAddr, err) + } + + if n != len(buf) { + return fmt.Errorf("failed to send full packet to %s, read %d but send %d", serverAddr, len(buf), n) + } + return nil +} + +func (r *runtime) getSession(clientAddr *net.UDPAddr) (*session, error) { + key := clientAddr.String() + + r.mu.Lock() + defer r.mu.Unlock() + + s, ok := r.sessions[key] + if ok && !s.isClosed() { + return s, nil + } + + serverConn, serverAddr, err := r.getServerConn(nil, clientAddr) + if err != nil { + return nil, err + } + + onClose := func() { + r.mu.Lock() + delete(r.sessions, key) + r.mu.Unlock() + } + s = newSession(clientAddr, serverAddr, serverConn, r.done, onClose, + time.Duration(r.spec.ServerIdleTimeout)*time.Millisecond, + time.Duration(r.spec.ClientIdleTimeout)*time.Millisecond) + s.ListenResponse(r.serverConn) + + r.sessions[key] = s + return s, nil +} + +func (r *runtime) proxy(clientAddr *net.UDPAddr, buf []byte) { + s, err := r.getSession(clientAddr) + if err != nil { + logger.Errorf("%s", err.Error()) + return + } + + dup := iobufferpool.UDPBufferPool.Get().([]byte) + n := copy(dup, buf) + err = s.Write(&iobufferpool.Packet{Payload: dup[:n], Len: n}) + if err != nil { + logger.Errorf("write data to udp session(%s) failed, err: %v", clientAddr.IP.String(), err) + } +} diff --git a/pkg/object/udpproxy/session.go b/pkg/object/udpproxy/session.go new file mode 100644 index 0000000000..a99d7336d1 --- /dev/null +++ b/pkg/object/udpproxy/session.go @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package udpproxy + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/util/fasttime" + "github.com/megaease/easegress/pkg/util/iobufferpool" + "github.com/megaease/easegress/pkg/util/timerpool" +) + +type ( + session struct { + clientAddr *net.UDPAddr + serverAddr string + serverConn net.Conn + clientIdleTimeout time.Duration + serverIdleTimeout time.Duration + writeBuf chan *iobufferpool.Packet + + stopped bool + stopChan chan struct{} + listenerStop chan struct{} + onClose func() + + mu sync.Mutex + } +) + +func newSession(clientAddr *net.UDPAddr, serverAddr string, serverConn net.Conn, + listenerStop chan struct{}, onClose func(), + clientIdleTimeout, serverIdleTimeout time.Duration) *session { + s := session{ + serverAddr: serverAddr, + clientAddr: clientAddr, + serverConn: serverConn, + serverIdleTimeout: serverIdleTimeout, + clientIdleTimeout: clientIdleTimeout, + writeBuf: make(chan *iobufferpool.Packet, 256), + + stopped: false, + stopChan: make(chan struct{}), + listenerStop: listenerStop, + onClose: onClose, + } + + go s.startSession(serverAddr, clientIdleTimeout) + return &s +} + +func (s *session) startSession(serverAddr string, clientIdleTimeout time.Duration) { + var t *time.Timer + var idleCheck <-chan time.Time + + if clientIdleTimeout > 0 { + t = time.NewTimer(clientIdleTimeout) + idleCheck = t.C + } + + for { + select { + case <-s.listenerStop: + s.close() + case <-idleCheck: + s.close() + case buf, ok := <-s.writeBuf: + if !ok { + s.close() + continue + } + + if t != nil { + if !t.Stop() { + <-t.C + } + t.Reset(clientIdleTimeout) + } + + bufLen := len(buf.Payload) + n, err := s.serverConn.Write(buf.Bytes()) + buf.Release() + + if err != nil { + logger.Errorf("udp connection flush data to server(%s) failed, err: %+v", serverAddr, err) + s.close() + continue + } + + if bufLen != n { + logger.Errorf("udp connection flush data to server(%s) failed, should write %d but written %d", + serverAddr, bufLen, n) + s.close() + } + case <-s.stopChan: + if t != nil { + t.Stop() + } + _ = s.serverConn.Close() + s.cleanWriteBuf() + s.onClose() + return + } + } +} + +// Write send data to buffer channel, wait flush to server +func (s *session) Write(buf *iobufferpool.Packet) error { + select { + case s.writeBuf <- buf: + return nil // try to send data with no check + default: + } + + var t *time.Timer + if s.serverIdleTimeout != 0 { + t = timerpool.Get(s.serverIdleTimeout * time.Millisecond) + } else { + t = timerpool.Get(60 * time.Second) + } + defer timerpool.Put(t) + + select { + case s.writeBuf <- buf: + return nil + case <-s.stopChan: + buf.Release() + return nil + case <-t.C: + buf.Release() + return fmt.Errorf("write data to channel timeout") + } +} + +// ListenResponse session listen server connection response and send to client +func (s *session) ListenResponse(sendTo *net.UDPConn) { + go func() { + buf := iobufferpool.UDPBufferPool.Get().([]byte) + defer s.close() + + for { + if s.serverIdleTimeout > 0 { + _ = s.serverConn.SetReadDeadline(fasttime.Now().Add(s.serverIdleTimeout)) + } + + n, err := s.serverConn.Read(buf) + if err != nil { + select { + case <-s.stopChan: + return // if session has closed, exit + default: + } + + if err, ok := err.(net.Error); ok && err.Timeout() { + continue + } + return + } + + nWrite, err := sendTo.WriteToUDP(buf[0:n], s.clientAddr) + if err != nil { + logger.Errorf("udp connection send data to client(%s) failed, err: %+v", s.clientAddr.String(), err) + return + } + + if n != nWrite { + logger.Errorf("udp connection send data to client(%s) failed, should write %d but written %d", + s.clientAddr.String(), n, nWrite) + return + } + } + }() +} + +func (s *session) cleanWriteBuf() { + for { + select { + case buf := <-s.writeBuf: + if buf != nil { + buf.Release() + } + default: + return + } + } +} + +// isClosed determine session if it is closed, used only for clean sessionMap +func (s *session) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.stopped +} + +// close send session close signal +func (s *session) close() { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped == true { + return + } + + s.stopped = true + s.onClose() + close(s.stopChan) +} diff --git a/pkg/object/udpproxy/spec.go b/pkg/object/udpproxy/spec.go new file mode 100644 index 0000000000..0f6a3bb4a5 --- /dev/null +++ b/pkg/object/udpproxy/spec.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package udpproxy + +import ( + "github.com/megaease/easegress/pkg/util/ipfilter" + "github.com/megaease/easegress/pkg/util/layer4backend" +) + +type ( + + // Spec describes the udp server + Spec struct { + Name string `yaml:"name" json:"name" jsonschema:"required"` + Port uint16 `yaml:"port" json:"port" jsonschema:"required"` + + // HasResponse client udp request has response? + HasResponse bool `yaml:"hasResponse" jsonschema:"required"` + ClientIdleTimeout uint32 `yaml:"clientIdleTimeout" jsonschema:"omitempty,minimum=1"` + ServerIdleTimeout uint32 `yaml:"serverIdleTimeout" jsonschema:"omitempty,minimum=1"` + + Pool *layer4backend.Spec `yaml:"pool" jsonschema:"required"` + IPFilter *ipfilter.Spec `yaml:"ipFilters,omitempty" jsonschema:"omitempty"` + } +) + +// Validate validates Layer4 Server. +func (spec *Spec) Validate() error { + if poolErr := spec.Pool.Validate(); poolErr != nil { + return poolErr + } + + return nil +} diff --git a/pkg/object/udpproxy/udpserver.go b/pkg/object/udpproxy/udpserver.go new file mode 100644 index 0000000000..b5ff94d510 --- /dev/null +++ b/pkg/object/udpproxy/udpserver.go @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package udpproxy + +import ( + "net" + "sync" + + "github.com/megaease/easegress/pkg/supervisor" +) + +const ( + // Category is the category of TCPServer. + Category = supervisor.CategoryBusinessController + + // Kind is the kind of TCPServer. + Kind = "UDPServer" +) + +func init() { + supervisor.Register(&UDPServer{}) +} + +type ( + // UDPServer is Object of udp server. + UDPServer struct { + runtime *runtime + } + + connPool struct { + pool map[string]net.Conn + mu sync.RWMutex + } +) + +// Category get object category +func (u *UDPServer) Category() supervisor.ObjectCategory { + return Category +} + +// Kind get object kind +func (u *UDPServer) Kind() string { + return Kind +} + +// DefaultSpec get default spec of UDPServer +func (u *UDPServer) DefaultSpec() interface{} { + return &Spec{} +} + +// Status get UDPServer status +func (u *UDPServer) Status() *supervisor.Status { + return &supervisor.Status{} +} + +// Close actually close runtime +func (u *UDPServer) Close() { + u.runtime.close() +} + +// Init initializes UDPServer. +func (u *UDPServer) Init(superSpec *supervisor.Spec) { + u.runtime = newRuntime(superSpec) +} + +// Inherit inherits previous generation of UDPServer. +func (u *UDPServer) Inherit(superSpec *supervisor.Spec, previousGeneration supervisor.Object) { + + u.runtime = previousGeneration.(*UDPServer).runtime + u.runtime.close() + u.Init(superSpec) +} + +func newConnPool() *connPool { + return &connPool{ + pool: make(map[string]net.Conn), + } +} + +func (c *connPool) get(addr string) net.Conn { + if c == nil { + return nil + } + + c.mu.RLock() + defer c.mu.RUnlock() + return c.pool[addr] +} + +func (c *connPool) put(addr string, conn net.Conn) { + if c == nil { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + c.pool[addr] = conn +} + +func (c *connPool) close() { + if c == nil { + return + } + + for _, conn := range c.pool { + _ = conn.Close() + } + c.pool = nil +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 8cabf6b560..dc0e3c5082 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -54,7 +54,9 @@ import ( _ "github.com/megaease/easegress/pkg/object/nacosserviceregistry" _ "github.com/megaease/easegress/pkg/object/pipeline" _ "github.com/megaease/easegress/pkg/object/rawconfigtrafficcontroller" + _ "github.com/megaease/easegress/pkg/object/tcpproxy" _ "github.com/megaease/easegress/pkg/object/trafficcontroller" + _ "github.com/megaease/easegress/pkg/object/udpproxy" _ "github.com/megaease/easegress/pkg/object/websocketserver" _ "github.com/megaease/easegress/pkg/object/zookeeperserviceregistry" ) diff --git a/pkg/supervisor/supervisor.go b/pkg/supervisor/supervisor.go index b697a05667..37ea8bb821 100644 --- a/pkg/supervisor/supervisor.go +++ b/pkg/supervisor/supervisor.go @@ -202,7 +202,7 @@ func (s *Supervisor) ObjectRegistry() *ObjectRegistry { return s.objectRegistry } -// WalkControllers walks every controllers until walkFn returns false. +// WalkControllers walks every controller until walkFn returns false. func (s *Supervisor) WalkControllers(walkFn WalkFunc) { defer func() { if err := recover(); err != nil { diff --git a/pkg/util/iobufferpool/constants.go b/pkg/util/iobufferpool/constants.go new file mode 100644 index 0000000000..99ae7e25c6 --- /dev/null +++ b/pkg/util/iobufferpool/constants.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package iobufferpool + +const ( + // UDPPacketMaxSize max size of udp packet + UDPPacketMaxSize = 65535 + // DefaultBufferReadCapacity default buffer capacity for stream proxy such as tcp + DefaultBufferReadCapacity = 1 << 16 +) diff --git a/pkg/util/iobufferpool/packet_pool.go b/pkg/util/iobufferpool/packet_pool.go new file mode 100644 index 0000000000..bc0b6daab5 --- /dev/null +++ b/pkg/util/iobufferpool/packet_pool.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package iobufferpool + +import ( + "sync" +) + +// UDPBufferPool udp buffer pool +var UDPBufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, UDPPacketMaxSize) + }, +} + +// Packet udp connection msg +type Packet struct { + Payload []byte + Len int +} + +// Bytes return underlying bytes for io buffer +func (p *Packet) Bytes() []byte { + if p.Payload == nil { + return nil + } + + return p.Payload[0:p.Len] +} + +// Release return io buffer resource to pool +func (p *Packet) Release() { + if p.Payload == nil { + return + } + UDPBufferPool.Put(p.Payload[:UDPPacketMaxSize]) +} diff --git a/pkg/util/iobufferpool/stream_buffer_pool.go b/pkg/util/iobufferpool/stream_buffer_pool.go new file mode 100644 index 0000000000..eec2229e2d --- /dev/null +++ b/pkg/util/iobufferpool/stream_buffer_pool.go @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package iobufferpool + +import ( + "github.com/valyala/bytebufferpool" +) + +// StreamBuffer io buffer for stream scene +type StreamBuffer struct { + payload *bytebufferpool.ByteBuffer + eof bool +} + +// NewStreamBuffer create stream buffer with specific payload +func NewStreamBuffer(buf []byte) *StreamBuffer { + res := &StreamBuffer{ + payload: bytebufferpool.Get(), + eof: false, + } + res.payload.Reset() + _, _ = res.payload.Write(buf) + return res +} + +// NewEOFStreamBuffer create stream buffer with eof sign +func NewEOFStreamBuffer() *StreamBuffer { + res := &StreamBuffer{ + payload: bytebufferpool.Get(), + eof: true, + } + res.payload.Reset() + return res +} + +// Bytes return underlying bytes +func (s *StreamBuffer) Bytes() []byte { + return s.payload.B +} + +// Len get buffer len +func (s *StreamBuffer) Len() int { + return len(s.payload.B) +} + +// Write implements io.Writer +func (s *StreamBuffer) Write(p []byte) (int, error) { + s.payload.B = append(s.payload.B, p...) + return len(p), nil +} + +// Release put buffer resource to pool +func (s *StreamBuffer) Release() { + if s.payload == nil { + return + } + s.payload.Reset() + bytebufferpool.Put(s.payload) +} + +// EOF return eof sign +func (s *StreamBuffer) EOF() bool { + return s.eof +} + +// SetEOF set eof sign +func (s *StreamBuffer) SetEOF(eof bool) { + s.eof = eof +} diff --git a/pkg/util/ipfilter/layer4ipfilters.go b/pkg/util/ipfilter/layer4ipfilters.go new file mode 100644 index 0000000000..7e90321f3c --- /dev/null +++ b/pkg/util/ipfilter/layer4ipfilters.go @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ipfilter + +import ( + "reflect" + "sync/atomic" +) + +type ( + // Layer4IpFilters layer4 ip filters + Layer4IpFilters struct { + rules atomic.Value + } + + ipFiltersRules struct { + spec *Spec + ipFilter *IPFilter + } +) + +// NewLayer4IPFilters create layer4 ip filters +func NewLayer4IPFilters(spec *Spec) *Layer4IpFilters { + m := &Layer4IpFilters{} + if spec == nil { + m.rules.Store(&ipFiltersRules{}) + } else { + m.rules.Store(&ipFiltersRules{ + spec: spec, + ipFilter: New(spec), + }) + } + return m +} + +// AllowIP check whether the IP is allowed to pass +func (i *Layer4IpFilters) AllowIP(ip string) bool { + rules := i.rules.Load().(*ipFiltersRules) + if rules.spec == nil { + return true + } + return rules.ipFilter.Allow(ip) +} + +// ReloadRules reload layer4 ip filters rules +func (i *Layer4IpFilters) ReloadRules(spec *Spec) { + if spec == nil { + i.rules.Store(&ipFiltersRules{}) + return + } + + old := i.rules.Load().(*ipFiltersRules) + if reflect.DeepEqual(old.spec, spec) { + return + } + + rules := &ipFiltersRules{ + spec: spec, + ipFilter: New(spec), + } + i.rules.Store(rules) +} + +func (r *ipFiltersRules) pass(clientIP string) bool { + if r.ipFilter == nil { + return true + } + return r.ipFilter.Allow(clientIP) +} diff --git a/pkg/util/layer4backend/backendserver.go b/pkg/util/layer4backend/backendserver.go new file mode 100644 index 0000000000..2f3dcfd1a9 --- /dev/null +++ b/pkg/util/layer4backend/backendserver.go @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package layer4backend + +import ( + "fmt" + "math/rand" + "sync" + "sync/atomic" + + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/object/serviceregistry" + "github.com/megaease/easegress/pkg/supervisor" + "github.com/megaease/easegress/pkg/util/hashtool" + "github.com/megaease/easegress/pkg/util/stringtool" +) + +const ( + // PolicyRoundRobin is the policy of round-robin. + PolicyRoundRobin = "roundRobin" + // PolicyRandom is the policy of random. + PolicyRandom = "random" + // PolicyWeightedRandom is the policy of weighted random. + PolicyWeightedRandom = "weightedRandom" + // PolicyIPHash is the policy of ip hash. + PolicyIPHash = "ipHash" +) + +type ( + servers struct { + poolSpec *Spec + super *supervisor.Supervisor + + mutex sync.Mutex + serviceRegistry *serviceregistry.ServiceRegistry + serviceWatcher serviceregistry.ServiceWatcher + static *staticServers + + done chan struct{} + } + + staticServers struct { + count uint64 + weightsSum int + servers []*Server + lb LoadBalance + } + + // Server is proxy server. + Server struct { + Addr string `yaml:"url" jsonschema:"required,format=hostport"` + Tags []string `yaml:"tags" jsonschema:"omitempty,uniqueItems=true"` + Weight int `yaml:"weight" jsonschema:"omitempty,minimum=0,maximum=100"` + } + + // LoadBalance is load balance for multiple servers. + LoadBalance struct { + Policy string `yaml:"policy" jsonschema:"required,enum=roundRobin,enum=random,enum=weightedRandom,enum=ipHash"` + } +) + +func newServers(super *supervisor.Supervisor, poolSpec *Spec) *servers { + s := &servers{ + poolSpec: poolSpec, + super: super, + done: make(chan struct{}), + } + + s.useStaticServers() + if poolSpec.ServiceRegistry == "" || poolSpec.ServiceName == "" { + return s + } + + s.serviceRegistry = s.super.MustGetSystemController(serviceregistry.Kind). + Instance().(*serviceregistry.ServiceRegistry) + s.tryUseService() + s.serviceWatcher = s.serviceRegistry.NewServiceWatcher(s.poolSpec.ServiceRegistry, s.poolSpec.ServiceName) + + go s.watchService() + return s +} + +// String backend server info +func (s *Server) String() string { + return fmt.Sprintf("%s,%v,%d", s.Addr, s.Tags, s.Weight) +} + +func (s *servers) watchService() { + for { + select { + case <-s.done: + return + case event := <-s.serviceWatcher.Watch(): + s.handleEvent(event) + } + } +} + +func (s *servers) handleEvent(event *serviceregistry.ServiceEvent) { + s.useService(event.Instances) +} + +func (s *servers) tryUseService() { + serviceInstanceSpecs, err := s.serviceRegistry.ListServiceInstances(s.poolSpec.ServiceRegistry, s.poolSpec.ServiceName) + + if err != nil { + logger.Errorf("get service %s/%s failed: %v", + s.poolSpec.ServiceRegistry, s.poolSpec.ServiceName, err) + s.useStaticServers() + return + } + s.useService(serviceInstanceSpecs) +} + +func (s *servers) useService(serviceInstanceSpecs map[string]*serviceregistry.ServiceInstanceSpec) { + var servers []*Server + for _, instance := range serviceInstanceSpecs { + servers = append(servers, &Server{ + Addr: fmt.Sprintf("%s:%d", instance.Address, instance.Port), + Tags: instance.Tags, + Weight: instance.Weight, + }) + } + if len(servers) == 0 { + logger.Errorf("%s/%s: empty service instance", + s.poolSpec.ServiceRegistry, s.poolSpec.ServiceName) + s.useStaticServers() + return + } + + dynamicServers := newStaticServers(servers, s.poolSpec.ServersTags, s.poolSpec.LoadBalance) + if dynamicServers.len() == 0 { + logger.Errorf("%s/%s: no service instance satisfy tags: %v", + s.poolSpec.ServiceRegistry, s.poolSpec.ServiceName, s.poolSpec.ServersTags) + s.useStaticServers() + } + + logger.Infof("use dynamic service: %s/%s", s.poolSpec.ServiceRegistry, s.poolSpec.ServiceName) + + s.mutex.Lock() + defer s.mutex.Unlock() + s.static = dynamicServers +} + +func (s *servers) useStaticServers() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.static = newStaticServers(s.poolSpec.Servers, s.poolSpec.ServersTags, s.poolSpec.LoadBalance) +} + +func (s *servers) snapshot() *staticServers { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.static +} + +func (s *servers) len() int { + static := s.snapshot() + return static.len() +} + +func (s *servers) next(cliAddr string) (*Server, error) { + static := s.snapshot() + if static.len() == 0 { + return nil, fmt.Errorf("no server available") + } + return static.next(cliAddr), nil +} + +func (s *servers) close() { + close(s.done) + + if s.serviceWatcher != nil { + s.serviceWatcher.Stop() + } +} + +func newStaticServers(servers []*Server, tags []string, lb *LoadBalance) *staticServers { + if servers == nil { + servers = make([]*Server, 0) + } + + ss := &staticServers{} + if lb == nil { + ss.lb.Policy = PolicyRoundRobin + } else { + ss.lb = *lb + } + + defer ss.prepare() + + if len(tags) == 0 { + ss.servers = servers + return ss + } + + chosenServers := make([]*Server, 0) + for _, server := range servers { + for _, tag := range tags { + if stringtool.StrInSlice(tag, server.Tags) { + chosenServers = append(chosenServers, server) + break + } + } + } + ss.servers = chosenServers + return ss +} + +func (ss *staticServers) prepare() { + for _, server := range ss.servers { + ss.weightsSum += server.Weight + } +} + +func (ss *staticServers) len() int { + return len(ss.servers) +} + +func (ss *staticServers) next(cliAddr string) *Server { + switch ss.lb.Policy { + case PolicyRoundRobin: + return ss.roundRobin() + case PolicyRandom: + return ss.random() + case PolicyWeightedRandom: + return ss.weightedRandom() + case PolicyIPHash: + return ss.ipHash(cliAddr) + } + logger.Errorf("BUG: unknown load balance policy: %s", ss.lb.Policy) + return ss.roundRobin() +} + +func (ss *staticServers) roundRobin() *Server { + count := atomic.AddUint64(&ss.count, 1) + // NOTE: startEventLoop from 0. + count-- + return ss.servers[int(count)%len(ss.servers)] +} + +func (ss *staticServers) random() *Server { + return ss.servers[rand.Intn(len(ss.servers))] +} + +func (ss *staticServers) weightedRandom() *Server { + randomWeight := rand.Intn(ss.weightsSum) + for _, server := range ss.servers { + randomWeight -= server.Weight + if randomWeight < 0 { + return server + } + } + + logger.Errorf("BUG: weighted random can't pick a server: sum(%d) servers(%+v)", + ss.weightsSum, ss.servers) + + return ss.random() +} + +func (ss *staticServers) ipHash(cliAddr string) *Server { + sum32 := int(hashtool.Hash32(cliAddr)) + return ss.servers[sum32%len(ss.servers)] +} diff --git a/pkg/util/layer4backend/pool.go b/pkg/util/layer4backend/pool.go new file mode 100644 index 0000000000..20d802e109 --- /dev/null +++ b/pkg/util/layer4backend/pool.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package layer4backend + +import ( + "reflect" + "sync/atomic" + + "github.com/megaease/easegress/pkg/supervisor" +) + +type ( + // Pool backend servers pool + Pool struct { + rules atomic.Value + } + + // pool backend server pool + poolRules struct { + spec *Spec + + tagPrefix string + servers *servers + } +) + +// NewPool create backend server pool +func NewPool(super *supervisor.Supervisor, spec *Spec, tagPrefix string) *Pool { + p := &Pool{} + + p.rules.Store(&poolRules{ + spec: spec, + + tagPrefix: tagPrefix, + servers: newServers(super, spec), + }) + return p +} + +// Next choose one backend for proxy +func (p *Pool) Next(cliAddr string) (*Server, error) { + rules := p.rules.Load().(*poolRules) + return rules.servers.next(cliAddr) +} + +// Close shutdown backend servers watcher +func (p *Pool) Close() { + rules := p.rules.Load().(*poolRules) + rules.servers.close() +} + +// ReloadRules reload backend servers pool rule +func (p *Pool) ReloadRules(super *supervisor.Supervisor, spec *Spec, tagPrefix string) { + old := p.rules.Load().(*poolRules) + if reflect.DeepEqual(old.spec, spec) { + return + } + + p.rules.Store(&poolRules{ + spec: spec, + + tagPrefix: tagPrefix, + servers: newServers(super, spec), + }) + p.Close() +} diff --git a/pkg/util/layer4backend/spec.go b/pkg/util/layer4backend/spec.go new file mode 100644 index 0000000000..0186b37f0e --- /dev/null +++ b/pkg/util/layer4backend/spec.go @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package layer4backend + +import "fmt" + +// Spec describes a pool of servers. +type Spec struct { + ServiceRegistry string `yaml:"serviceRegistry" jsonschema:"omitempty"` + ServiceName string `yaml:"serviceName" jsonschema:"omitempty"` + Servers []*Server `yaml:"servers" jsonschema:"omitempty"` + ServersTags []string `yaml:"serversTags" jsonschema:"omitempty,uniqueItems=true"` + LoadBalance *LoadBalance `yaml:"loadBalance" jsonschema:"required"` +} + +// Validate validates poolSpec. +func (s *Spec) Validate() error { + if s.ServiceName == "" && len(s.Servers) == 0 { + return fmt.Errorf("both serviceName and servers are empty") + } + + serversGotWeight := 0 + for _, server := range s.Servers { + if server.Weight > 0 { + serversGotWeight++ + } + } + if serversGotWeight > 0 && serversGotWeight < len(s.Servers) { + return fmt.Errorf("not all servers have weight(%d/%d)", + serversGotWeight, len(s.Servers)) + } + + if s.ServiceName == "" { + servers := newStaticServers(s.Servers, s.ServersTags, s.LoadBalance) + if servers.len() == 0 { + return fmt.Errorf("serversTags picks none of servers") + } + } + return nil +} diff --git a/pkg/util/limitlistener/limitlistener.go b/pkg/util/limitlistener/limitlistener.go index b91aad1c6e..e16e4885bd 100644 --- a/pkg/util/limitlistener/limitlistener.go +++ b/pkg/util/limitlistener/limitlistener.go @@ -47,8 +47,8 @@ type LimitListener struct { closeOnce sync.Once // ensures the done chan is only closed once } -// acquire acquires the limiting semaphore. Returns true if successfully -// accquired, false if the listener is closed and the semaphore is not +// acquire the limiting semaphore. Returns true if successfully +// acquired, false if the listener is closed and the semaphore is not // acquired. func (l *LimitListener) acquire() bool { return l.sem.AcquireWithContext(l.ctx) == nil @@ -73,7 +73,7 @@ func (l *LimitListener) Accept() (net.Conn, error) { l.release() return nil, err } - return &limitListenerConn{Conn: c, release: l.release}, nil + return &Conn{Conn: c, release: l.release}, nil } // SetMaxConnection sets max connection. @@ -88,13 +88,15 @@ func (l *LimitListener) Close() error { return err } -type limitListenerConn struct { +// Conn limit listener connection +type Conn struct { net.Conn releaseOnce sync.Once release func() } -func (l *limitListenerConn) Close() error { +// Close release semaphore and close connection +func (l *Conn) Close() error { err := l.Conn.Close() l.releaseOnce.Do(l.release) return err diff --git a/pkg/util/timerpool/timerpool.go b/pkg/util/timerpool/timerpool.go new file mode 100644 index 0000000000..fa01a5a6bf --- /dev/null +++ b/pkg/util/timerpool/timerpool.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package timerpool + +import ( + "sync" + "time" +) + +// copy from https://github.com/nats-io/nats.go/blob/main/timer.go + +// global pool of *time.Timer's. can be used by multiple goroutines concurrently. +var globalTimerPool timerPool + +// timerPool provides GC-able pooling of *time.Timer's. +// can be used by multiple goroutines concurrently. +type timerPool struct { + p sync.Pool +} + +// Get returns a timer that completes after the given duration. +func Get(d time.Duration) *time.Timer { + if t, _ := globalTimerPool.p.Get().(*time.Timer); t != nil { + t.Reset(d) + return t + } + + return time.NewTimer(d) +} + +// Put pools the given timer. +// +// There is no need to call t.stop() before calling Put. +// +// Put will try to stop the timer before pooling. If the +// given timer already expired, Put will read the unreceived +// value if there is one. +func Put(t *time.Timer) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + + globalTimerPool.p.Put(t) +}