diff --git a/go.mod b/go.mod index 48f1a76e..5fdd9884 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/dsnet/golib/memfile v1.0.0 + github.com/karlseguin/ccache/v3 v3.0.6 github.com/pion/dtls/v3 v3.0.2 github.com/stretchr/testify v1.9.0 go.uber.org/atomic v1.11.0 diff --git a/go.sum b/go.sum index 5c8dab3c..7cdafeb8 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dsnet/golib/memfile v1.0.0 h1:J9pUspY2bDCbF9o+YGwcf3uG6MdyITfh/Fk3/CaEiFs= github.com/dsnet/golib/memfile v1.0.0/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= +github.com/karlseguin/ccache/v3 v3.0.6 h1:6wC04CXSdptebuSUBgsQixNrrRMUdimtwmjlJUpCf/4= +github.com/karlseguin/ccache/v3 v3.0.6/go.mod h1:b0qfdUOHl4vJgKFQN41paXIdBb3acAtyX2uWrBAZs1w= github.com/pion/dtls/v3 v3.0.2 h1:425DEeJ/jfuTTghhUDW0GtYZYIwwMtnKKJNMcWccTX0= github.com/pion/dtls/v3 v3.0.2/go.mod h1:dfIXcFkKoujDQ+jtd8M6RgqKK3DuaUilm3YatAbGp5k= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= diff --git a/udp/client/conn.go b/udp/client/conn.go index 54d215f3..59d09fa1 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -6,9 +6,11 @@ import ( "fmt" "math" "net" + "strconv" "sync" "time" + "github.com/karlseguin/ccache/v3" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/plgd-dev/go-coap/v3/message/pool" @@ -20,7 +22,6 @@ import ( "github.com/plgd-dev/go-coap/v3/net/observation" "github.com/plgd-dev/go-coap/v3/net/responsewriter" "github.com/plgd-dev/go-coap/v3/options/config" - "github.com/plgd-dev/go-coap/v3/pkg/cache" coapErrors "github.com/plgd-dev/go-coap/v3/pkg/errors" "github.com/plgd-dev/go-coap/v3/pkg/fn" pkgMath "github.com/plgd-dev/go-coap/v3/pkg/math" @@ -127,6 +128,57 @@ func (m *midElement) GetMessage(cc *Conn) (*pool.Message, bool, error) { return msg, true, nil } +// MessageCache is a cache of CoAP messages. +type MessageCache interface { + Load(key string, msg *pool.Message) (bool, error) + Store(key string, msg *pool.Message) error + Close() +} + +// messageCache is a CoAP message cache backed by +type messageCache struct { + c *ccache.Cache[*pool.Message] + pool *pool.Pool +} + +// newMessageCache constructs a new CoAP message cache. +func newMessageCache() *messageCache { + p := pool.New(30, 1600) + return &messageCache{ + c: ccache.New(ccache.Configure[*pool.Message]().MaxSize(30).OnDelete(func(item *ccache.Item[*pool.Message]) { + p.ReleaseMessage(item.Value()) + })), + pool: p, + } +} + +// Load loads a message from the cache if one exists with key. +func (m *messageCache) Load(key string, msg *pool.Message) (bool, error) { + item := m.c.Get(key) + if item == nil || item.Expired() { + return false, nil + } + if err := item.Value().Clone(msg); err != nil { + return false, err + } + return true, nil +} + +// Store stores a message in the cache. +func (m *messageCache) Store(key string, msg *pool.Message) error { + cached := m.pool.AcquireMessage(context.Background()) + if err := msg.Clone(cached); err != nil { + return err + } + m.c.Set(key, cached, ExchangeLifetime) + return nil +} + +// Close closes the cache. +func (m *messageCache) Close() { + m.c.Stop() +} + // Conn represents a virtual connection to a conceptual endpoint, to perform COAPs commands. type Conn struct { // This field needs to be the first in the struct to ensure proper word alignment on 32-bit platforms. @@ -145,7 +197,7 @@ type Conn struct { processReceivedMessage config.ProcessReceivedMessageFunc[*Conn] errors ErrorFunc - responseMsgCache *cache.Cache[string, []byte] + responseMsgCache MessageCache msgIDMutex *MutexMap tokenHandlerContainer *coapSync.Map[uint64, HandlerFunc] @@ -192,6 +244,7 @@ type ConnOptions struct { createBlockWise func(cc *Conn) *blockwise.BlockWise[*Conn] inactivityMonitor InactivityMonitor requestMonitor RequestMonitorFunc + responseMsgCache MessageCache } type Option = func(opts *ConnOptions) @@ -220,6 +273,23 @@ func WithRequestMonitor(requestMonitor RequestMonitorFunc) Option { } } +// WithResponseMessageCache sets the cache used for response messages. All +// response messages are submitted to the cache, but it is up to the cache +// implementation to determine which messages are stored and for how long. +// Caching responses enables sending the same Acknowledgment for retransmitted +// confirmable messages within an EXCHANGE_LIFETIME. It may be desirable to +// relax this behavior in some scenarios. +// https://datatracker.ietf.org/doc/html/rfc7252#section-4.5 +// The default response message cache uses an LRU cache with capacity of 30 +// items and expiration of 247 seconds, which is EXCHANGE_LIFETIME when using +// default CoAP transmission parameters. +// https://datatracker.ietf.org/doc/html/rfc7252#section-4.8.2 +func WithResponseMessageCache(cache MessageCache) Option { + return func(opts *ConnOptions) { + opts.responseMsgCache = cache + } +} + func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { if cfg.Errors == nil { cfg.Errors = func(error) { @@ -248,6 +318,10 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { for _, o := range opts { o(&cfgOpts) } + // Only construct cache if one was not set via options. + if cfgOpts.responseMsgCache == nil { + cfgOpts.responseMsgCache = newMessageCache() + } cc := Conn{ session: session, transmission: &Transmission{ @@ -262,7 +336,7 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { processReceivedMessage: cfg.ProcessReceivedMessage, errors: cfg.Errors, msgIDMutex: NewMutexMap(), - responseMsgCache: cache.NewCache[string, []byte](), + responseMsgCache: cfgOpts.responseMsgCache, inactivityMonitor: cfgOpts.inactivityMonitor, requestMonitor: cfgOpts.requestMonitor, messagePool: cfg.MessagePool, @@ -318,6 +392,7 @@ func (cc *Conn) Close() error { if errors.Is(err, net.ErrClosed) { return nil } + cc.responseMsgCache.Close() return err } @@ -609,34 +684,14 @@ func (cc *Conn) Sequence() uint64 { return cc.sequence.Add(1) } -func (cc *Conn) responseMsgCacheID(msgID int32) string { - return fmt.Sprintf("resp-%v-%d", cc.RemoteAddr(), msgID) +// getResponseFromCache gets a message from the response message cache. +func (cc *Conn) getResponseFromCache(mid int32, resp *pool.Message) (bool, error) { + return cc.responseMsgCache.Load(strconv.Itoa(int(mid)), resp) } +// addResponseToCache adds a message to the response message cache. func (cc *Conn) addResponseToCache(resp *pool.Message) error { - marshaledResp, err := resp.MarshalWithEncoder(coder.DefaultCoder) - if err != nil { - return err - } - cacheMsg := make([]byte, len(marshaledResp)) - copy(cacheMsg, marshaledResp) - cc.responseMsgCache.LoadOrStore(cc.responseMsgCacheID(resp.MessageID()), cache.NewElement(cacheMsg, time.Now().Add(ExchangeLifetime), nil)) - return nil -} - -func (cc *Conn) getResponseFromCache(mid int32, resp *pool.Message) (bool, error) { - cachedResp := cc.responseMsgCache.Load(cc.responseMsgCacheID(mid)) - if cachedResp == nil { - return false, nil - } - if rawMsg := cachedResp.Data(); len(rawMsg) > 0 { - _, err := resp.UnmarshalWithDecoder(coder.DefaultCoder, rawMsg) - if err != nil { - return false, err - } - return true, nil - } - return false, nil + return cc.responseMsgCache.Store(strconv.Itoa(int(resp.MessageID())), resp) } // checkMyMessageID compare client msgID against peer messageID and if it is near < 0xffff/4 then increase msgID. @@ -907,7 +962,6 @@ func (cc *Conn) checkMidHandlerContainer(now time.Time, maxRetransmit uint32, ac // CheckExpirations checks and remove expired items from caches. func (cc *Conn) CheckExpirations(now time.Time) { cc.inactivityMonitor.CheckInactivity(now, cc) - cc.responseMsgCache.CheckExpirations(now) if cc.blockWise != nil { cc.blockWise.CheckExpirations(now) }