Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory bloat for unresponded to packets #33

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions IDGenerator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package probing

import (
"math"

"github.com/google/uuid"
)

type sequenceBucket struct {
bitmap []bool
usedCount int64
}

type idGenerator struct {
awaitingSequences map[uuid.UUID]*sequenceBucket
currentCount int
currentUUID uuid.UUID
oldestUUIDQueue []uuid.UUID
totalOutstandingSequences int64
deletionPolicyFn func(*idGenerator)
}

func newIDGenerator(deletionPolicy func(*idGenerator)) *idGenerator {
return &idGenerator{
awaitingSequences: make(map[uuid.UUID]*sequenceBucket),
deletionPolicyFn: deletionPolicy,
}
}

func (g *idGenerator) next() (uuid.UUID, int) {
if g.currentCount == 0 {
g.currentUUID = uuid.New()
}

var ok bool
if _, ok = g.awaitingSequences[g.currentUUID]; !ok {
inFlightSeqs := new(sequenceBucket)
inFlightSeqs.bitmap = make([]bool, math.MaxUint16)
g.awaitingSequences[g.currentUUID] = inFlightSeqs
g.oldestUUIDQueue = append(g.oldestUUIDQueue, g.currentUUID)

}

nextSeq := g.currentCount
nextUUID := g.currentUUID
g.totalOutstandingSequences++

g.awaitingSequences[g.currentUUID].bitmap[g.currentCount] = true
g.awaitingSequences[g.currentUUID].usedCount++

// Run the deletion policy. May modify awaitingSequences
g.deletionPolicyFn(g)

g.currentCount = (g.currentCount + 1) % math.MaxUint16
return nextUUID, nextSeq
}

func (g *idGenerator) find(UUID uuid.UUID, sequenceNum int) bool {
if val, ok := g.awaitingSequences[UUID]; ok {
return val.bitmap[sequenceNum]
}
return false
}

func (g *idGenerator) retire(UUID uuid.UUID, sequenceNum int) {
if val, ok := g.awaitingSequences[UUID]; ok {
val.bitmap[sequenceNum] = false
val.usedCount--
}
g.totalOutstandingSequences--
}

func (g *idGenerator) retireBucket(UUID uuid.UUID) {
if val, ok := g.awaitingSequences[UUID]; ok {
g.totalOutstandingSequences -= val.usedCount
delete(g.awaitingSequences, UUID)
}
}

func RemoveAfterMaxItems(maxItems int64) func(*idGenerator) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Public functions should have a doc comment.

Suggested change
func RemoveAfterMaxItems(maxItems int64) func(*idGenerator) {
// RemoveAfterMaxItems does some thing, I don't know what.
func RemoveAfterMaxItems(maxItems int64) func(*idGenerator) {

return func(g *idGenerator) {
if g.totalOutstandingSequences == maxItems {
toDeleteUUID := g.oldestUUIDQueue[0]
g.oldestUUIDQueue = g.oldestUUIDQueue[1:]
g.retireBucket(toDeleteUUID)
}
}
}
43 changes: 15 additions & 28 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ var (
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
defaultDeletionPolicy := func(*idGenerator) {}
return &Pinger{
Count: -1,
Interval: time.Second,
Expand All @@ -104,12 +102,11 @@ func New(addr string) *Pinger {
addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
awaitingSequences: newIDGenerator(defaultDeletionPolicy),
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
Expand Down Expand Up @@ -205,14 +202,11 @@ type Pinger struct {
// df when true sets the do-not-fragment bit in the outer IP or IPv6 header
df bool

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
awaitingSequences *idGenerator
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -686,18 +680,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
if err != nil {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
}
return nil, nil
return &packetUUID, nil
}

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
return p.awaitingSequences.currentUUID
}

func (p *Pinger) processPacket(recv *packet) error {
Expand Down Expand Up @@ -750,15 +738,16 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
if !p.awaitingSequences.find(*pktUUID, pkt.Seq) {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}

// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
p.awaitingSequences.retire(*pktUUID, pkt.Seq)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -778,7 +767,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}

currentUUID := p.getCurrentTrackerUUID()
currentUUID, sequenceNumber := p.awaitingSequences.next()
p.sequence = sequenceNumber
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
Expand Down Expand Up @@ -835,15 +825,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
p.OnSend(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.sequence = 0
}
break
}

Expand Down Expand Up @@ -873,6 +855,11 @@ func (p *Pinger) listen() (packetConn, error) {
return conn, nil
}

// SetOutstandingPacketsPolicy Sets the deletion policy for packets which have not received an ICMP response
func (p *Pinger) SetOutstandingPacketsPolicy(deletionPolicy func(*idGenerator)) {
p.awaitingSequences = newIDGenerator(deletionPolicy)
}

func bytesToTime(b []byte) time.Time {
var nsec int64
for i := uint8(0); i < 8; i++ {
Expand Down
30 changes: 23 additions & 7 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestProcessPacket(t *testing.T) {
shouldBe1++
}

currentUUID := pinger.getCurrentTrackerUUID()
currentUUID, sequenceNum := pinger.awaitingSequences.next()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
Expand All @@ -36,10 +36,9 @@ func TestProcessPacket(t *testing.T) {

body := &icmp.Echo{
ID: pinger.id,
Seq: pinger.sequence,
Seq: sequenceNum,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -609,7 +608,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
dups++
}

currentUUID := pinger.getCurrentTrackerUUID()
// register the sequence as sent
currentUUID, sequenceNum := pinger.awaitingSequences.next()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
Expand All @@ -621,11 +621,9 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {

body := &icmp.Echo{
ID: 123,
Seq: 0,
Seq: sequenceNum,
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -833,3 +831,21 @@ func TestRunWithBackgroundContext(t *testing.T) {
}
AssertTrue(t, stats.PacketsRecv == 10)
}

func TestBlackhole(t *testing.T) {
pinger := New("127.0.0.2")
pinger.Count = 99
pinger.Interval = time.Microsecond
pinger.Timeout = time.Second
pinger.SetOutstandingPacketsPolicy(RemoveAfterMaxItems(10))

err := pinger.Resolve()
AssertNoError(t, err)

conn := new(testPacketConn)

err = pinger.run(context.Background(), conn)
AssertNoError(t, err)

AssertTrue(t, pinger.awaitingSequences.totalOutstandingSequences == 9)
}