Skip to content

Commit

Permalink
Refactor: WithContext type
Browse files Browse the repository at this point in the history
  • Loading branch information
louisroyer committed Feb 20, 2025
1 parent 6777146 commit a6c664a
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 68 deletions.
17 changes: 5 additions & 12 deletions internal/amf/amf.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/netip"
"time"

"github.com/nextmn/cp-lite/internal/common"
"github.com/nextmn/cp-lite/internal/smf"

"github.com/nextmn/json-api/healthcheck"
Expand All @@ -22,15 +23,14 @@ import (
)

type Amf struct {
common.WithContext

control jsonapi.ControlURI
client http.Client
userAgent string
smf *smf.Smf
srv *http.Server
closed chan struct{}

// not exported because must not be modified
ctx context.Context
}

func NewAmf(bindAddr netip.AddrPort, control jsonapi.ControlURI, userAgent string, smf *smf.Smf) *Amf {
Expand Down Expand Up @@ -62,8 +62,8 @@ func NewAmf(bindAddr netip.AddrPort, control jsonapi.ControlURI, userAgent strin
}

func (amf *Amf) Start(ctx context.Context) error {
if ctx == nil {
return ErrNilCtx
if err := amf.InitContext(ctx); err != nil {
return err
}
l, err := net.Listen("tcp", amf.srv.Addr)
if err != nil {
Expand Down Expand Up @@ -106,10 +106,3 @@ func Status(c *gin.Context) {
c.Header("Cache-Control", "no-cache")
c.JSON(http.StatusOK, status)
}

func (amf *Amf) Context() context.Context {
if amf.ctx != nil {
return amf.ctx
}
return context.Background()
}
4 changes: 2 additions & 2 deletions internal/amf/errors.go → internal/common/errors.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright 2024 Louis Royer and the NextMN contributors. All rights reserved.
// Copyright Louis Royer and the NextMN contributors. All rights reserved.
// Use of this source code is governed by a MIT-style license that can be
// found in the LICENSE file.
// SPDX-License-Identifier: MIT

package amf
package common

import (
"errors"
Expand Down
30 changes: 30 additions & 0 deletions internal/common/with-context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright Louis Royer and the NextMN contributors. All rights reserved.
// Use of this source code is governed by a MIT-style license that can be
// found in the LICENSE file.
// SPDX-License-Identifier: MIT

package common

import (
"context"
)

type WithContext struct {
// not exported because must not be modified
ctx context.Context
}

func (wc *WithContext) InitContext(ctx context.Context) error {
if ctx == nil {
return ErrNilCtx
}
wc.ctx = ctx
return nil
}

func (wc *WithContext) Context() context.Context {
if wc.ctx != nil {
return wc.ctx
}
return context.Background()
}
40 changes: 17 additions & 23 deletions internal/smf/smf.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/netip"
"time"

"github.com/nextmn/cp-lite/internal/common"
"github.com/nextmn/cp-lite/internal/config"

pfcp "github.com/nextmn/go-pfcp-networking/pfcp"
Expand All @@ -20,15 +21,14 @@ import (
)

type Smf struct {
common.WithContext

upfs *UpfsMap
slices *SlicesMap
Areas AreasMap
srv *pfcp.PFCPEntityCP
started bool
closed chan struct{}

// not exported because must not be modified
ctx context.Context
}

func NewSmf(addr netip.Addr, slices map[string]config.Slice, areas map[string]config.Area) *Smf {
Expand All @@ -40,18 +40,16 @@ func NewSmf(addr netip.Addr, slices map[string]config.Slice, areas map[string]co
upfs: upfs,
Areas: NewAreasMap(areas),
closed: make(chan struct{}),
ctx: nil,
}
}

func (smf *Smf) Start(ctx context.Context) error {
if smf.started {
return ErrSmfAlreadyStarted
}
if ctx == nil {
return ErrNilCtx
if err := smf.InitContext(ctx); err != nil {
return err
}
smf.ctx = ctx
logrus.Info("Starting PFCP Server")
go func() {
defer func() {
Expand Down Expand Up @@ -94,15 +92,8 @@ func (smf *Smf) Start(ctx context.Context) error {
return nil
}

func (smf *Smf) Context() context.Context {
if smf.ctx != nil {
return smf.ctx
}
return context.Background()
}

func (smf *Smf) CreateSessionDownlink(ueCtrl jsonapi.ControlURI, ueIp netip.Addr, dnn string, gnbCtrl jsonapi.ControlURI, gnbFteid jsonapi.Fteid) (*PduSessionN3, error) {
return smf.CreateSessionDownlinkContext(smf.ctx, ueCtrl, ueIp, dnn, gnbCtrl, gnbFteid)
return smf.CreateSessionDownlinkContext(smf.Context(), ueCtrl, ueIp, dnn, gnbCtrl, gnbFteid)
}

func (smf *Smf) CreateSessionDownlinkContext(ctx context.Context, ueCtrl jsonapi.ControlURI, ueIp netip.Addr, dnn string, gnbCtrl jsonapi.ControlURI, gnbFteid jsonapi.Fteid) (*PduSessionN3, error) {
Expand All @@ -112,13 +103,14 @@ func (smf *Smf) CreateSessionDownlinkContext(ctx context.Context, ueCtrl jsonapi
if ctx == nil {
return nil, ErrNilCtx
}
smfCtx := smf.Context()
select {
case <-ctx.Done():
// if ctx is over, abort
return nil, ctx.Err()
case <-smf.ctx.Done():
case <-smfCtx.Done():
// if smf.ctx is over, abort
return nil, smf.ctx.Err()
return nil, smfCtx.Err()
default:
}
// check for existing session
Expand Down Expand Up @@ -178,13 +170,14 @@ func (smf *Smf) CreateSessionDownlinkFWUpfIContext(ctx context.Context, ueCtrl j
if ctx == nil {
return nil, ErrNilCtx
}
smfCtx := smf.Context()
select {
case <-ctx.Done():
// if ctx is over, abort
return nil, ctx.Err()
case <-smf.ctx.Done():
case <-smfCtx.Done():
// if smf.ctx is over, abort
return nil, smf.ctx.Err()
return nil, smfCtx.Err()
default:
}

Expand Down Expand Up @@ -241,7 +234,7 @@ func (smf *Smf) GetNextUeIpAddr(dnn string) (netip.Addr, error) {
}

func (smf *Smf) CreateSessionUplink(ueCtrl jsonapi.ControlURI, ueIpAddr netip.Addr, gnbCtrl jsonapi.ControlURI, dnn string) (*PduSessionN3, error) {
return smf.CreateSessionUplinkContext(smf.ctx, ueCtrl, ueIpAddr, gnbCtrl, dnn)
return smf.CreateSessionUplinkContext(smf.Context(), ueCtrl, ueIpAddr, gnbCtrl, dnn)
}

func (smf *Smf) CreateSessionUplinkContext(ctx context.Context, ueCtrl jsonapi.ControlURI, ueIpAddr netip.Addr, gnbCtrl jsonapi.ControlURI, dnn string) (*PduSessionN3, error) {
Expand All @@ -251,13 +244,14 @@ func (smf *Smf) CreateSessionUplinkContext(ctx context.Context, ueCtrl jsonapi.C
if ctx == nil {
return nil, ErrNilCtx
}
smfCtx := smf.Context()
select {
case <-ctx.Done():
// if ctx is over, abort
return nil, ctx.Err()
case <-smf.ctx.Done():
case <-smfCtx.Done():
// if smf.ctx is over, abort
return nil, smf.ctx.Err()
return nil, smfCtx.Err()
default:
}
// check for existing session
Expand Down Expand Up @@ -398,7 +392,7 @@ func (smf *Smf) GetNextDownlinkFteid(ueCtrl jsonapi.ControlURI, ueAddr netip.Add
}

func (smf *Smf) UpdateSessionDownlink(ueCtrl jsonapi.ControlURI, ueAddr netip.Addr, dnn string, oldGnbCtrl jsonapi.ControlURI) error {
return smf.UpdateSessionDownlinkContext(smf.ctx, ueCtrl, ueAddr, dnn, oldGnbCtrl)
return smf.UpdateSessionDownlinkContext(smf.Context(), ueCtrl, ueAddr, dnn, oldGnbCtrl)
}

// Updates Session to NextDownlinkFteid
Expand Down
22 changes: 8 additions & 14 deletions internal/smf/teids_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import (
"context"
"math/rand"
"sync"

"github.com/nextmn/cp-lite/internal/common"
)

type TEIDsPool struct {
common.WithContext

teids map[uint32]struct{}
sync.Mutex

// not exported because must not be modified
ctx context.Context
}

func NewTEIDsPool() *TEIDsPool {
Expand All @@ -25,18 +26,11 @@ func NewTEIDsPool() *TEIDsPool {
}
}

func (t *TEIDsPool) Init(ctx context.Context) error {
if ctx == nil {
return ErrNilCtx
}
t.ctx = ctx
return nil
}

// Returns next TEID from the pool.
// warning: the pool must first be initialized using `Init(ctx)`
func (t *TEIDsPool) Next(ctx context.Context) (uint32, error) {
if t.ctx == nil || ctx == nil {
tCtx := t.Context()
if ctx == nil {
return 0, ErrNilCtx
}
t.Lock()
Expand All @@ -46,8 +40,8 @@ func (t *TEIDsPool) Next(ctx context.Context) (uint32, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-t.ctx.Done():
return 0, t.ctx.Err()
case <-tCtx.Done():
return 0, tCtx.Err()
default:
teid = rand.Uint32()
if teid == 0 {
Expand Down
34 changes: 17 additions & 17 deletions internal/smf/upf.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/netip"
"sync"

"github.com/nextmn/cp-lite/internal/common"
"github.com/nextmn/cp-lite/internal/config"

pfcp "github.com/nextmn/go-pfcp-networking/pfcp"
Expand Down Expand Up @@ -38,12 +39,10 @@ func NewUpfsMap(slices map[string]config.Slice) *UpfsMap {
}

type Upf struct {
common.WithContext
association pfcpapi.PFCPAssociationInterface
interfaces map[netip.Addr]*UpfInterface
sessions map[netip.Addr]*Pfcprules

// not exported because must not be modified
ctx context.Context
}

func NewUpf(interfaces []config.Interface) *Upf {
Expand All @@ -55,13 +54,12 @@ func NewUpf(interfaces []config.Interface) *Upf {
}

func (upf *Upf) Associate(ctx context.Context, a pfcpapi.PFCPAssociationInterface) error {
if ctx == nil {
return ErrNilCtx
if err := upf.InitContext(ctx); err != nil {
return err
}
upf.ctx = ctx
// Initialize TeidPools
for _, iface := range upf.interfaces {
if err := iface.Teids.Init(ctx); err != nil {
if err := iface.Teids.InitContext(ctx); err != nil {
return err
}
}
Expand All @@ -79,18 +77,19 @@ func (upf *Upf) Rules(ueIp netip.Addr) *Pfcprules {
}

func (upf *Upf) NextListenFteid(listenInterface netip.Addr) (*jsonapi.Fteid, error) {
return upf.NextListenFteidContext(upf.ctx, listenInterface)
return upf.NextListenFteidContext(upf.Context(), listenInterface)
}

func (upf *Upf) NextListenFteidContext(ctx context.Context, listenInterface netip.Addr) (*jsonapi.Fteid, error) {
if ctx == nil || upf.ctx == nil {
upfCtx := upf.Context()
if ctx == nil {
return nil, ErrNilCtx
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-upf.ctx.Done():
return nil, upf.ctx.Err()
case <-upfCtx.Done():
return nil, upfCtx.Err()
default:
}
iface, ok := upf.interfaces[listenInterface]
Expand All @@ -108,18 +107,19 @@ func (upf *Upf) NextListenFteidContext(ctx context.Context, listenInterface neti
}

func (upf *Upf) CreateUplinkIntermediate(ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, error) {
return upf.CreateUplinkIntermediateContext(upf.ctx, ueIp, dnn, listenInterface, forwardFteid)
return upf.CreateUplinkIntermediateContext(upf.Context(), ueIp, dnn, listenInterface, forwardFteid)
}

func (upf *Upf) CreateUplinkIntermediateContext(ctx context.Context, ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, error) {
if ctx == nil || upf.ctx == nil {
if ctx == nil {
return nil, ErrNilCtx
}
upfCtx := upf.Context()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-upf.ctx.Done():
return nil, upf.ctx.Err()
case <-upfCtx.Done():
return nil, upfCtx.Err()
default:
}
listenFteid, err := upf.NextListenFteidContext(ctx, listenInterface)
Expand Down Expand Up @@ -163,7 +163,7 @@ func (upf *Upf) CreateUplinkIntermediateWithFteid(ueIp netip.Addr, dnn string, l
}

func (upf *Upf) CreateUplinkAnchor(ueIp netip.Addr, dnn string, listenInterface netip.Addr) (*jsonapi.Fteid, error) {
return upf.CreateUplinkAnchorContext(upf.ctx, ueIp, dnn, listenInterface)
return upf.CreateUplinkAnchorContext(upf.Context(), ueIp, dnn, listenInterface)
}
func (upf *Upf) CreateUplinkAnchorContext(ctx context.Context, ueIp netip.Addr, dnn string, listenInterface netip.Addr) (*jsonapi.Fteid, error) {
if ctx == nil {
Expand Down Expand Up @@ -254,7 +254,7 @@ func (upf *Upf) UpdateDownlinkIntermediateDirectForward(ueIp netip.Addr, dnn str
}

func (upf *Upf) UpdateDownlinkIntermediate(ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, uint32, error) {
return upf.UpdateDownlinkIntermediateContext(upf.ctx, ueIp, dnn, listenInterface, forwardFteid)
return upf.UpdateDownlinkIntermediateContext(upf.Context(), ueIp, dnn, listenInterface, forwardFteid)
}
func (upf *Upf) UpdateDownlinkIntermediateContext(ctx context.Context, ueIp netip.Addr, dnn string, listenInterface netip.Addr, forwardFteid *jsonapi.Fteid) (*jsonapi.Fteid, uint32, error) {
if ctx == nil {
Expand Down

0 comments on commit a6c664a

Please sign in to comment.