From a6c664a05e28bedab17430d90521df12b31b6c46 Mon Sep 17 00:00:00 2001 From: Louis Royer Date: Thu, 20 Feb 2025 12:14:55 +0100 Subject: [PATCH] Refactor: WithContext type --- internal/amf/amf.go | 17 ++++--------- internal/{amf => common}/errors.go | 4 +-- internal/common/with-context.go | 30 ++++++++++++++++++++++ internal/smf/smf.go | 40 +++++++++++++----------------- internal/smf/teids_pool.go | 22 ++++++---------- internal/smf/upf.go | 34 ++++++++++++------------- 6 files changed, 79 insertions(+), 68 deletions(-) rename internal/{amf => common}/errors.go (69%) create mode 100644 internal/common/with-context.go diff --git a/internal/amf/amf.go b/internal/amf/amf.go index 190ce01..6667897 100644 --- a/internal/amf/amf.go +++ b/internal/amf/amf.go @@ -12,6 +12,7 @@ import ( "net/netip" "time" + "github.com/nextmn/cp-lite/internal/common" "github.com/nextmn/cp-lite/internal/smf" "github.com/nextmn/json-api/healthcheck" @@ -22,15 +23,14 @@ import ( ) type Amf struct { + common.WithContext + control jsonapi.ControlURI client http.Client userAgent string smf *smf.Smf srv *http.Server closed chan struct{} - - // not exported because must not be modified - ctx context.Context } func NewAmf(bindAddr netip.AddrPort, control jsonapi.ControlURI, userAgent string, smf *smf.Smf) *Amf { @@ -62,8 +62,8 @@ func NewAmf(bindAddr netip.AddrPort, control jsonapi.ControlURI, userAgent strin } func (amf *Amf) Start(ctx context.Context) error { - if ctx == nil { - return ErrNilCtx + if err := amf.InitContext(ctx); err != nil { + return err } l, err := net.Listen("tcp", amf.srv.Addr) if err != nil { @@ -106,10 +106,3 @@ func Status(c *gin.Context) { c.Header("Cache-Control", "no-cache") c.JSON(http.StatusOK, status) } - -func (amf *Amf) Context() context.Context { - if amf.ctx != nil { - return amf.ctx - } - return context.Background() -} diff --git a/internal/amf/errors.go b/internal/common/errors.go similarity index 69% rename from internal/amf/errors.go rename to internal/common/errors.go index e6c7e03..949070b 100644 --- a/internal/amf/errors.go +++ b/internal/common/errors.go @@ -1,9 +1,9 @@ -// Copyright 2024 Louis Royer and the NextMN contributors. All rights reserved. +// Copyright Louis Royer and the NextMN contributors. All rights reserved. // Use of this source code is governed by a MIT-style license that can be // found in the LICENSE file. // SPDX-License-Identifier: MIT -package amf +package common import ( "errors" diff --git a/internal/common/with-context.go b/internal/common/with-context.go new file mode 100644 index 0000000..451e41f --- /dev/null +++ b/internal/common/with-context.go @@ -0,0 +1,30 @@ +// Copyright Louis Royer and the NextMN contributors. All rights reserved. +// Use of this source code is governed by a MIT-style license that can be +// found in the LICENSE file. +// SPDX-License-Identifier: MIT + +package common + +import ( + "context" +) + +type WithContext struct { + // not exported because must not be modified + ctx context.Context +} + +func (wc *WithContext) InitContext(ctx context.Context) error { + if ctx == nil { + return ErrNilCtx + } + wc.ctx = ctx + return nil +} + +func (wc *WithContext) Context() context.Context { + if wc.ctx != nil { + return wc.ctx + } + return context.Background() +} diff --git a/internal/smf/smf.go b/internal/smf/smf.go index 394974b..01960a1 100644 --- a/internal/smf/smf.go +++ b/internal/smf/smf.go @@ -10,6 +10,7 @@ import ( "net/netip" "time" + "github.com/nextmn/cp-lite/internal/common" "github.com/nextmn/cp-lite/internal/config" pfcp "github.com/nextmn/go-pfcp-networking/pfcp" @@ -20,15 +21,14 @@ import ( ) type Smf struct { + common.WithContext + upfs *UpfsMap slices *SlicesMap Areas AreasMap srv *pfcp.PFCPEntityCP started bool closed chan struct{} - - // not exported because must not be modified - ctx context.Context } func NewSmf(addr netip.Addr, slices map[string]config.Slice, areas map[string]config.Area) *Smf { @@ -40,7 +40,6 @@ func NewSmf(addr netip.Addr, slices map[string]config.Slice, areas map[string]co upfs: upfs, Areas: NewAreasMap(areas), closed: make(chan struct{}), - ctx: nil, } } @@ -48,10 +47,9 @@ func (smf *Smf) Start(ctx context.Context) error { if smf.started { return ErrSmfAlreadyStarted } - if ctx == nil { - return ErrNilCtx + if err := smf.InitContext(ctx); err != nil { + return err } - smf.ctx = ctx logrus.Info("Starting PFCP Server") go func() { defer func() { @@ -94,15 +92,8 @@ func (smf *Smf) Start(ctx context.Context) error { return nil } -func (smf *Smf) Context() context.Context { - if smf.ctx != nil { - return smf.ctx - } - return context.Background() -} - func (smf *Smf) CreateSessionDownlink(ueCtrl jsonapi.ControlURI, ueIp netip.Addr, dnn string, gnbCtrl jsonapi.ControlURI, gnbFteid jsonapi.Fteid) (*PduSessionN3, error) { - return smf.CreateSessionDownlinkContext(smf.ctx, ueCtrl, ueIp, dnn, gnbCtrl, gnbFteid) + return smf.CreateSessionDownlinkContext(smf.Context(), ueCtrl, ueIp, dnn, gnbCtrl, gnbFteid) } func (smf *Smf) CreateSessionDownlinkContext(ctx context.Context, ueCtrl jsonapi.ControlURI, ueIp netip.Addr, dnn string, gnbCtrl jsonapi.ControlURI, gnbFteid jsonapi.Fteid) (*PduSessionN3, error) { @@ -112,13 +103,14 @@ func (smf *Smf) CreateSessionDownlinkContext(ctx context.Context, ueCtrl jsonapi if ctx == nil { return nil, ErrNilCtx } + smfCtx := smf.Context() select { case <-ctx.Done(): // if ctx is over, abort return nil, ctx.Err() - case <-smf.ctx.Done(): + case <-smfCtx.Done(): // if smf.ctx is over, abort - return nil, smf.ctx.Err() + return nil, smfCtx.Err() default: } // check for existing session @@ -178,13 +170,14 @@ func (smf *Smf) CreateSessionDownlinkFWUpfIContext(ctx context.Context, ueCtrl j if ctx == nil { return nil, ErrNilCtx } + smfCtx := smf.Context() select { case <-ctx.Done(): // if ctx is over, abort return nil, ctx.Err() - case <-smf.ctx.Done(): + case <-smfCtx.Done(): // if smf.ctx is over, abort - return nil, smf.ctx.Err() + return nil, smfCtx.Err() default: } @@ -241,7 +234,7 @@ func (smf *Smf) GetNextUeIpAddr(dnn string) (netip.Addr, error) { } func (smf *Smf) CreateSessionUplink(ueCtrl jsonapi.ControlURI, ueIpAddr netip.Addr, gnbCtrl jsonapi.ControlURI, dnn string) (*PduSessionN3, error) { - return smf.CreateSessionUplinkContext(smf.ctx, ueCtrl, ueIpAddr, gnbCtrl, dnn) + return smf.CreateSessionUplinkContext(smf.Context(), ueCtrl, ueIpAddr, gnbCtrl, dnn) } func (smf *Smf) CreateSessionUplinkContext(ctx context.Context, ueCtrl jsonapi.ControlURI, ueIpAddr netip.Addr, gnbCtrl jsonapi.ControlURI, dnn string) (*PduSessionN3, error) { @@ -251,13 +244,14 @@ func (smf *Smf) CreateSessionUplinkContext(ctx context.Context, ueCtrl jsonapi.C if ctx == nil { return nil, ErrNilCtx } + smfCtx := smf.Context() select { case <-ctx.Done(): // if ctx is over, abort return nil, ctx.Err() - case <-smf.ctx.Done(): + case <-smfCtx.Done(): // if smf.ctx is over, abort - return nil, smf.ctx.Err() + return nil, smfCtx.Err() default: } // check for existing session @@ -398,7 +392,7 @@ func (smf *Smf) GetNextDownlinkFteid(ueCtrl jsonapi.ControlURI, ueAddr netip.Add } func (smf *Smf) UpdateSessionDownlink(ueCtrl jsonapi.ControlURI, ueAddr netip.Addr, dnn string, oldGnbCtrl jsonapi.ControlURI) error { - return smf.UpdateSessionDownlinkContext(smf.ctx, ueCtrl, ueAddr, dnn, oldGnbCtrl) + return smf.UpdateSessionDownlinkContext(smf.Context(), ueCtrl, ueAddr, dnn, oldGnbCtrl) } // Updates Session to NextDownlinkFteid diff --git a/internal/smf/teids_pool.go b/internal/smf/teids_pool.go index 183b30d..e83ac5a 100644 --- a/internal/smf/teids_pool.go +++ b/internal/smf/teids_pool.go @@ -9,14 +9,15 @@ import ( "context" "math/rand" "sync" + + "github.com/nextmn/cp-lite/internal/common" ) type TEIDsPool struct { + common.WithContext + teids map[uint32]struct{} sync.Mutex - - // not exported because must not be modified - ctx context.Context } func NewTEIDsPool() *TEIDsPool { @@ -25,18 +26,11 @@ func NewTEIDsPool() *TEIDsPool { } } -func (t *TEIDsPool) Init(ctx context.Context) error { - if ctx == nil { - return ErrNilCtx - } - t.ctx = ctx - return nil -} - // Returns next TEID from the pool. // warning: the pool must first be initialized using `Init(ctx)` func (t *TEIDsPool) Next(ctx context.Context) (uint32, error) { - if t.ctx == nil || ctx == nil { + tCtx := t.Context() + if ctx == nil { return 0, ErrNilCtx } t.Lock() @@ -46,8 +40,8 @@ func (t *TEIDsPool) Next(ctx context.Context) (uint32, error) { select { case <-ctx.Done(): return 0, ctx.Err() - case <-t.ctx.Done(): - return 0, t.ctx.Err() + case <-tCtx.Done(): + return 0, tCtx.Err() default: teid = rand.Uint32() if teid == 0 { diff --git a/internal/smf/upf.go b/internal/smf/upf.go index 0229083..7de88e3 100644 --- a/internal/smf/upf.go +++ b/internal/smf/upf.go @@ -10,6 +10,7 @@ import ( "net/netip" "sync" + "github.com/nextmn/cp-lite/internal/common" "github.com/nextmn/cp-lite/internal/config" pfcp "github.com/nextmn/go-pfcp-networking/pfcp" @@ -38,12 +39,10 @@ func NewUpfsMap(slices map[string]config.Slice) *UpfsMap { } type Upf struct { + common.WithContext association pfcpapi.PFCPAssociationInterface interfaces map[netip.Addr]*UpfInterface sessions map[netip.Addr]*Pfcprules - - // not exported because must not be modified - ctx context.Context } func NewUpf(interfaces []config.Interface) *Upf { @@ -55,13 +54,12 @@ func NewUpf(interfaces []config.Interface) *Upf { } func (upf *Upf) Associate(ctx context.Context, a pfcpapi.PFCPAssociationInterface) error { - if ctx == nil { - return ErrNilCtx + if err := upf.InitContext(ctx); err != nil { + return err } - upf.ctx = ctx // Initialize TeidPools for _, iface := range upf.interfaces { - if err := iface.Teids.Init(ctx); err != nil { + if err := iface.Teids.InitContext(ctx); err != nil { return err } } @@ -79,18 +77,19 @@ func (upf *Upf) Rules(ueIp netip.Addr) *Pfcprules { } func (upf *Upf) NextListenFteid(listenInterface netip.Addr) (*jsonapi.Fteid, error) { - return upf.NextListenFteidContext(upf.ctx, listenInterface) + return upf.NextListenFteidContext(upf.Context(), listenInterface) } func (upf *Upf) NextListenFteidContext(ctx context.Context, listenInterface netip.Addr) (*jsonapi.Fteid, error) { - if ctx == nil || upf.ctx == nil { + upfCtx := upf.Context() + if ctx == nil { return nil, ErrNilCtx } select { case <-ctx.Done(): return nil, ctx.Err() - case <-upf.ctx.Done(): - return nil, upf.ctx.Err() + case <-upfCtx.Done(): + return nil, upfCtx.Err() default: } iface, ok := upf.interfaces[listenInterface] @@ -108,18 +107,19 @@ func (upf *Upf) NextListenFteidContext(ctx context.Context, listenInterface neti } func (upf *Upf) CreateUplinkIntermediate(ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, error) { - return upf.CreateUplinkIntermediateContext(upf.ctx, ueIp, dnn, listenInterface, forwardFteid) + return upf.CreateUplinkIntermediateContext(upf.Context(), ueIp, dnn, listenInterface, forwardFteid) } func (upf *Upf) CreateUplinkIntermediateContext(ctx context.Context, ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, error) { - if ctx == nil || upf.ctx == nil { + if ctx == nil { return nil, ErrNilCtx } + upfCtx := upf.Context() select { case <-ctx.Done(): return nil, ctx.Err() - case <-upf.ctx.Done(): - return nil, upf.ctx.Err() + case <-upfCtx.Done(): + return nil, upfCtx.Err() default: } listenFteid, err := upf.NextListenFteidContext(ctx, listenInterface) @@ -163,7 +163,7 @@ func (upf *Upf) CreateUplinkIntermediateWithFteid(ueIp netip.Addr, dnn string, l } func (upf *Upf) CreateUplinkAnchor(ueIp netip.Addr, dnn string, listenInterface netip.Addr) (*jsonapi.Fteid, error) { - return upf.CreateUplinkAnchorContext(upf.ctx, ueIp, dnn, listenInterface) + return upf.CreateUplinkAnchorContext(upf.Context(), ueIp, dnn, listenInterface) } func (upf *Upf) CreateUplinkAnchorContext(ctx context.Context, ueIp netip.Addr, dnn string, listenInterface netip.Addr) (*jsonapi.Fteid, error) { if ctx == nil { @@ -254,7 +254,7 @@ func (upf *Upf) UpdateDownlinkIntermediateDirectForward(ueIp netip.Addr, dnn str } func (upf *Upf) UpdateDownlinkIntermediate(ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, uint32, error) { - return upf.UpdateDownlinkIntermediateContext(upf.ctx, ueIp, dnn, listenInterface, forwardFteid) + return upf.UpdateDownlinkIntermediateContext(upf.Context(), ueIp, dnn, listenInterface, forwardFteid) } func (upf *Upf) UpdateDownlinkIntermediateContext(ctx context.Context, ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, uint32, error) { if ctx == nil {