diff --git a/netfilter.go b/netfilter.go
index ee8de85..4cc029a 100644
--- a/netfilter.go
+++ b/netfilter.go
@@ -4,93 +4,92 @@ import (
 	"log"
 	"os"
 	"os/signal"
-	"sync"
 	"syscall"
 
-	netfilter "github.com/AkihiroSuda/go-netfilter-queue"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
+	netfilter "github.com/AkihiroSuda/go-netfilter-queue"
 )
 
 /***Variables***/
 
-type NetfilterQueue struct {
+type NetFilterQueue struct {
 	// Set Variables
-	Handler    func(*RBKV, *PacketData) netfilter.Verdict
-	QueueNum   uint16
-	MaxWorkers int
+	Handler      func(*log.Logger, *RBKV, *PacketData) netfilter.Verdict
+	QueueNum     uint16
+	LogAllErrors bool
+	Logger       *log.Logger
 	// queue handler objects
 	nfq      *netfilter.NFQueue
 	pktQueue <-chan netfilter.NFPacket
-
-	// worker/class handler objects
-	started bool
-	wg      sync.WaitGroup
+	wp       *workerPool
 }
 
 /***Methods***/
 
-//(*NetfilterQueue).Start : spawn nfq instance and start collecting packets
-func (queue *NetfilterQueue) Start() {
+//(*NetFilterQueue).start : spawn nfq instance and start collecting packets
+func (q *NetFilterQueue) start() {
 	// check if already started
-	if queue.started {
-		log.Fatalf("NFQueue %d ALREADY STARTED!\n", queue.QueueNum)
+	if q.wp != nil {
+		q.Logger.Fatalf("NFQueue %d ALREADY STARTED!\n", q.QueueNum)
 	}
 	// spawn netfilter queue instance and start collecting packets
 	var err error
-	queue.nfq, err = netfilter.NewNFQueue(queue.QueueNum, 100, netfilter.NF_DEFAULT_PACKET_SIZE)
+	q.nfq, err = netfilter.NewNFQueue(q.QueueNum, 100, netfilter.NF_DEFAULT_PACKET_SIZE)
 	if err != nil {
-		log.Fatalf("NFQueue %d Error: %s\n", queue.QueueNum, err.Error())
+		log.Fatalf("NFQueue %d Error: %s\n", q.QueueNum, err.Error())
 	}
-	log.Printf("NFQueue: %d Initialized! Starting Workers...\n", queue.QueueNum)
+	log.Printf("NFQueue: %d Initialized! Starting WorkerPool...\n", q.QueueNum)
 	// set packet queue and started boolean
-	queue.pktQueue = queue.nfq.GetPackets()
-	queue.started = true
-	// start max number of workers
-	for i := 0; i < queue.MaxWorkers; i++ {
-		go queue.worker()
-		queue.wg.Add(1)
+	q.pktQueue = q.nfq.GetPackets()
+	// spawn workerpool
+	q.wp = &workerPool{
+		WorkerFunc:      q.handlePacket,
+		MaxWorkersCount: 10 * 1024,
+		LogAllErrors:    q.LogAllErrors,
+		Logger:          q.Logger,
 	}
-	log.Println("Workers Started!")
-}
-
-//(*NetfilterQueue).Wait : wait for threads to finish FOREVER!!! (A really long time)
-func (queue *NetfilterQueue) Wait() {
-	queue.wg.Wait()
+	q.wp.Start()
 }
 
-//(*NetfilterQueue).Stop : close nfq instance and stop collecting packets
-func (queue *NetfilterQueue) Stop() {
+//(*NetFilterQueue).stop : close nfq instance and stop collecting packets
+func (q *NetFilterQueue) stop() {
 	// check if not started
-	if !queue.started {
-		log.Fatalf("NFQueue %d NEVER STARTED!\n", queue.QueueNum)
+	if q.wp == nil {
+		log.Fatalf("NFQueue %d NEVER STARTED!\n", q.QueueNum)
 	}
-	// close queue instance
-	queue.nfq.Close()
-	// close packet queue and set started boolean
-	queue.pktQueue = nil
-	queue.started = false
+	// close/stop everything
+	q.nfq.Close()
+	q.wp.Stop()
+	q.pktQueue = nil
 }
 
-func (queue *NetfilterQueue) Run() {
+//(*NetFilterQueue).Run : run nfq indefinitely and block until interrupt
+func (q *NetFilterQueue) Run() {
 	// start netfilter queue instance
-	queue.Start()
-	// handle interupts
+	q.start()
+	// handle interrupts
 	c := make(chan os.Signal, 2)
 	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
 	go func() {
 		for sig := range c {
-			log.Fatalf("Captured Signal: %v! Cleaning up...", sig)
-			queue.Stop()
+			q.stop()
+			log.Fatalf("Captured Signal: %s! Cleaned up, exiting...", sig.String())
 		}
 	}()
 
-	// wait possibly forever
-	queue.Wait()
+	// handle incoming packets
+	var p netfilter.NFPacket
+	for {
+		p = <-q.pktQueue
+		if !q.wp.Serve(p) {
+			log.Println("worker error! serving connection failed!")
+		}
+	}
 }
 
-//(*NetfilterQueue).parsePacket : parse gopacket and return collected packet data
-func (queue *NetfilterQueue) parsePacket(packetin gopacket.Packet, packetout *PacketData) {
+//(*NetFilterQueue).parsePacket : parse gopacket and return collected packet data
+func (q *NetFilterQueue) parsePacket(packetin gopacket.Packet, packetout *PacketData) {
 	//get src and dst ip from ipv4
 	ipLayer := packetin.Layer(layers.LayerTypeIPv4)
 	if ipLayer != nil {
@@ -108,25 +107,18 @@ func (queue *NetfilterQueue) parsePacket(packetin gopacket.Packet, packetout *Pa
 	}
 }
 
-//(*NetfilterQueue).worker : worker instance used to set the verdict for queued packets
-func (queue *NetfilterQueue) worker() {
-	// defer waitgroup completion
-	defer queue.wg.Done()
+//(*NetFilterQueue).handlePacket : set the verdict for a single queued packet
+func (q *NetFilterQueue) handlePacket(p netfilter.NFPacket) error {
 	// init variables for packet handling
 	var (
-		nfqPacket  netfilter.NFPacket      //Reused netfilter packet object
-		dataPacket PacketData              //Reused parsed packet data as struct
-		redblackkv *RBKV = NewRedBlackKV() //Reused key/value pair for red black tree caches
+		dataPacket PacketData //Reused parsed packet data as struct
+		redBlackKV = &RBKV{}  //Reused key/value pair for red black tree caches
 	)
-	// loop while running forever
-	for queue.started {
-		// collect verdict packet from netfilerqueu
-		nfqPacket = <-queue.pktQueue
-		// parse packet for required information
-		queue.parsePacket(nfqPacket.Packet, &dataPacket)
-		// complete logic go get verfict on packet and set verdict
-		nfqPacket.SetVerdict(
-			queue.Handler(redblackkv, &dataPacket),
-		)
-	}
+	// parse packet for required information
+	q.parsePacket(p.Packet, &dataPacket)
+	// complete logic to get verdict on packet and set verdict
+	p.SetVerdict(
+		q.Handler(q.Logger, redBlackKV, &dataPacket),
+	)
+	return nil
 }
diff --git a/workpool.go b/workpool.go
new file mode 100644
index 0000000..29494d7
--- /dev/null
+++ b/workpool.go
@@ -0,0 +1,260 @@
+package goaway2
+
+import (
+	"log"
+	"time"
+	"sync"
+	"runtime"
+	"sync/atomic"
+
+	"github.com/AkihiroSuda/go-netfilter-queue"
+)
+
+//stolen from: https://github.com/valyala/fasthttp/blob/master/workerpool.go
+//stolen from: https://github.com/valyala/fasthttp/blob/master/coarseTime.go
+// this system uses a slightly tweaked version of the workerpool from fasthttp to handle and process
+// incoming packets from NetFilterQueue as fast as possible
+
+/* Variables */
+
+//timeStore : temporary store for time.Time in truncated form to allow for fast access / usage
+var timeStore atomic.Value
+
+// workerPool serves incoming connections via a pool of workers
+// in FILO order, i.e. the most recently stopped worker will serve the next
+// incoming connection.
+// Such a scheme keeps CPU caches hot (in theory).
+type workerPool struct {
+	// Function for serving server connections.
+	// It must leave c unclosed.
+	WorkerFunc            func(p netfilter.NFPacket) error
+	MaxWorkersCount       int
+	LogAllErrors          bool
+	MaxIdleWorkerDuration time.Duration
+	Logger                *log.Logger
+
+	lock           sync.Mutex
+	workersCount   int
+	mustStop       bool
+	ready          []*workerChan
+	stopCh         chan struct{}
+	workerChanPool sync.Pool
+}
+
+//workerChan : contains channel to handle given packets along with expiration timer
+type workerChan struct {
+	lastUseTime time.Time
+	ch          chan netfilter.NFPacket
+}
+
+var workerChanCap = func() int {
+	// Use blocking workerChan if GOMAXPROCS=1.
+	// This immediately switches Serve to WorkerFunc, which results
+	// in higher performance (under go1.5 at least).
+	if runtime.GOMAXPROCS(0) == 1 {
+		return 0
+	}
+
+	// Use non-blocking workerChan if GOMAXPROCS>1,
+	// since otherwise the Serve caller (Acceptor) may lag accepting
+	// new connections if WorkerFunc is CPU-bound.
+	return 1
+}()
+
+/* Functions */
+
+//CoarseTimeNow : return time truncated to seconds which
+// is faster than using non-truncated version
+func CoarseTimeNow() time.Time {
+	tp := timeStore.Load().(*time.Time)
+	return *tp
+}
+
+/* Init */
+func init() {
+	t := time.Now().Truncate(time.Second)
+	timeStore.Store(&t)
+	go func() {
+		for {
+			time.Sleep(time.Second)
+			t := time.Now().Truncate(time.Second)
+			timeStore.Store(&t)
+		}
+	}()
+}
+
+/* Methods */
+
+//(*workerPool).Start : start worker-pool
+func (wp *workerPool) Start() {
+	if wp.stopCh != nil {
+		panic("BUG: workerPool already started")
+	}
+	wp.stopCh = make(chan struct{})
+	stopCh := wp.stopCh
+	go func() {
+		var scratch []*workerChan
+		for {
+			wp.clean(&scratch)
+			select {
+			case <-stopCh:
+				return
+			default:
+				time.Sleep(wp.getMaxIdleWorkerDuration())
+			}
+		}
+	}()
+}
+
+//(*workerPool).Stop : stop worker-pool
+func (wp *workerPool) Stop() {
+	if wp.stopCh == nil {
+		panic("BUG: workerPool wasn't started")
+	}
+	close(wp.stopCh)
+	wp.stopCh = nil
+
+	// Stop all the workers waiting for incoming connections.
+	// Do not wait for busy workers - they will stop after
+	// serving the connection and noticing wp.mustStop = true.
+	wp.lock.Lock()
+	ready := wp.ready
+	wp.Logger.Printf("DBUG: stopping all workers!")
+	for i, ch := range ready {
+		ch.ch <- netfilter.NFPacket{Packet: nil}
+		ready[i] = nil
+	}
+	wp.ready = ready[:0]
+	wp.mustStop = true
+	wp.lock.Unlock()
+}
+
+//(*workerPool).getMaxIdleWorkerDuration : return MaxIdleWorkerDuration, falling back to a 10s default when unset
+func (wp *workerPool) getMaxIdleWorkerDuration() time.Duration {
+	if wp.MaxIdleWorkerDuration <= 0 {
+		return 10 * time.Second
+	}
+	return wp.MaxIdleWorkerDuration
+}
+
+//(*workerPool).clean : remove inactive workers
+func (wp *workerPool) clean(scratch *[]*workerChan) {
+	maxIdleWorkerDuration := wp.getMaxIdleWorkerDuration()
+
+	// Clean least recently used workers if they didn't serve connections
+	// for more than maxIdleWorkerDuration.
+	currentTime := time.Now()
+
+	wp.lock.Lock()
+	ready := wp.ready
+	n := len(ready)
+	i := 0
+	for i < n && currentTime.Sub(ready[i].lastUseTime) > maxIdleWorkerDuration {
+		i++
+	}
+	*scratch = append((*scratch)[:0], ready[:i]...)
+	if i > 0 {
+		m := copy(ready, ready[i:])
+		for i = m; i < n; i++ {
+			ready[i] = nil
+		}
+		wp.ready = ready[:m]
+	}
+	wp.lock.Unlock()
+
+	// Notify obsolete workers to stop.
+	// This notification must be outside the wp.lock, since ch.ch
+	// may be blocking and may consume a lot of time if many workers
+	// are located on non-local CPUs.
+	tmp := *scratch
+	for i, ch := range tmp {
+		ch.ch <- netfilter.NFPacket{Packet: nil}
+		tmp[i] = nil
+		wp.Logger.Printf("DBUG: attempting to clean worker!")
+	}
+}
+
+//(*workerPool).Serve : pass connection to workerPool to handle
+func (wp *workerPool) Serve(p netfilter.NFPacket) bool {
+	ch := wp.getCh()
+	if ch == nil {
+		return false
+	}
+	ch.ch <- p
+	return true
+}
+
+//(*workerPool).getCh : return available channel to pass packet for worker pool to handle
+func (wp *workerPool) getCh() *workerChan {
+	var ch *workerChan
+	createWorker := false
+
+	wp.lock.Lock()
+	ready := wp.ready
+	n := len(ready) - 1
+	if n < 0 {
+		if wp.workersCount < wp.MaxWorkersCount {
+			createWorker = true
+			wp.workersCount++
+		}
+	} else {
+		ch = ready[n]
+		ready[n] = nil
+		wp.ready = ready[:n]
+	}
+	wp.lock.Unlock()
+
+	if ch == nil {
+		if !createWorker {
+			return nil
+		}
+		vch := wp.workerChanPool.Get()
+		if vch == nil {
+			vch = &workerChan{
+				ch: make(chan netfilter.NFPacket, workerChanCap),
+			}
+		}
+		ch = vch.(*workerChan)
+		go func() {
+			wp.workerFunc(ch)
+			wp.workerChanPool.Put(vch)
+		}()
+	}
+	return ch
+}
+
+//(*workerPool).release : allow channel to be used among another worker
+func (wp *workerPool) release(ch *workerChan) bool {
+	ch.lastUseTime = CoarseTimeNow()
+	wp.lock.Lock()
+	if wp.mustStop {
+		wp.lock.Unlock()
+		return false
+	}
+	wp.ready = append(wp.ready, ch)
+	wp.lock.Unlock()
+	return true
+}
+
+//(*workerPool).workerFunc : worker function used to handle incoming connections via channels
+func (wp *workerPool) workerFunc(ch *workerChan) {
+	var p netfilter.NFPacket
+	var err error
+	for p = range ch.ch {
+		if p.Packet == nil {
+			break
+		}
+		if err = wp.WorkerFunc(p); err != nil {
+			if wp.LogAllErrors {
+				wp.Logger.Printf("error when handling packet: %s", err)
+			}
+		}
+		if !wp.release(ch) {
+			break
+		}
+	}
+	wp.lock.Lock()
+	wp.workersCount--
+	wp.lock.Unlock()
+	wp.Logger.Printf("DBUG: Worker Exited!")
+}
\ No newline at end of file
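
For context, a minimal usage sketch of the new API (not part of this diff; the wiring below is an assumption). It assumes an iptables rule such as iptables -A INPUT -j NFQUEUE --queue-num 0 already redirects traffic to queue 0, and that RBKV and PacketData are defined elsewhere in the existing package; the function name and logger setup are illustrative only.

// example_usage.go (hypothetical, for illustration only; not included in this change)
package goaway2

import (
	"log"
	"os"

	netfilter "github.com/AkihiroSuda/go-netfilter-queue"
)

// ExampleRun wires up the new NetFilterQueue type and blocks until SIGINT/SIGTERM.
func ExampleRun() {
	q := &NetFilterQueue{
		QueueNum:     0,
		LogAllErrors: true,
		Logger:       log.New(os.Stderr, "goaway2: ", log.LstdFlags),
		// Handler receives the shared logger, the reusable red-black-tree
		// key/value scratch pair, and the parsed packet data, and returns
		// the verdict that handlePacket applies to the queued packet.
		Handler: func(l *log.Logger, kv *RBKV, pkt *PacketData) netfilter.Verdict {
			return netfilter.NF_ACCEPT // accept everything in this sketch
		},
	}
	q.Run() // spawns the nfqueue and worker pool, then blocks on the packet loop
}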