From 6af2d20cee965f6af20c6d0490953d7318616662 Mon Sep 17 00:00:00 2001
From: Fred Rolland <frolland@nvidia.com>
Date: Mon, 21 Oct 2024 16:55:03 +0300
Subject: [PATCH] bug: rdma exlusive handling

In case a RDMA device in exclusive mode is in use
by a Pod, the DP was not reporting it as a resource
after DP restart.

Following changes are introduced in RdmaSpec:

- isRdma: in case of no rdma resources,
  check if netlink "enable_rdma" is available.
- GetRdmaDeviceSpec: the device specs are retrieved
  dynamically and not on discovery stage as before.

Dynamic RDMA specs computation vs on discovery, comes
to solve following scenario for exlusive mode:
- Discover RDMA device
- Allocate to Pod (resources are hidden on host)
- Restart DP pod
- Deallocate
- Reallocate

Fixes #565

Signed-off-by: Fred Rolland <frolland@nvidia.com>
---
 pkg/devices/rdma.go                | 80 ++++++++++++++++++++----------
 pkg/devices/rdma_test.go           | 49 ++++++++++++++++--
 pkg/factory/factory.go             | 10 +---
 pkg/factory/factory_test.go        |  6 +++
 pkg/utils/mocks/NetlinkProvider.go | 32 +++++++++++-
 pkg/utils/netlink_provider.go      | 22 ++++++++
 pkg/utils/utils.go                 | 11 ++++
 7 files changed, 169 insertions(+), 41 deletions(-)

diff --git a/pkg/devices/rdma.go b/pkg/devices/rdma.go
index 60d64b429..a77b6af22 100644
--- a/pkg/devices/rdma.go
+++ b/pkg/devices/rdma.go
@@ -18,6 +18,7 @@
 package devices
 
 import (
+	"github.com/golang/glog"
 	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
 
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
@@ -25,15 +26,61 @@ import (
 )
 
 type rdmaSpec struct {
-	isSupportRdma bool
-	deviceSpec    []*pluginapi.DeviceSpec
+	deviceID   string
+	deviceType types.DeviceType
 }
 
-func newRdmaSpec(rdmaResources []string) types.RdmaSpec {
+// NewRdmaSpec returns the RdmaSpec
+func NewRdmaSpec(dt types.DeviceType, id string) types.RdmaSpec {
+	if dt == types.AcceleratorType {
+		return nil
+	}
+	return &rdmaSpec{deviceID: id, deviceType: dt}
+}
+
+func (r *rdmaSpec) IsRdma() bool {
+	if len(r.getRdmaResources()) > 0 {
+		return true
+	}
+	var bus string
+	//nolint: exhaustive
+	switch r.deviceType {
+	case types.NetDeviceType:
+		bus = "pci"
+	case types.AuxNetDeviceType:
+		bus = "auxiliary"
+	default:
+		return false
+	}
+	// In case of exclusive RDMA, if the resource is assigned to a pod
+	// the files used to check if the device support RDMA are removed from the host.
+	// In order to still report the resource in this state,
+	// netlink param "enable_rdma" is checked to verify if the device supports RDMA.
+	// This scenario cann happen if the device is discovered, assigned to a pod and then the plugin is restarted.
+	rdma, err := utils.HasRdmaParam(bus, r.deviceID)
+	if err != nil {
+		glog.Infof("HasRdmaParam(): unable to get Netlink RDMA param for device %s : %q", r.deviceID, err)
+		return false
+	}
+	return rdma
+}
+
+func (r *rdmaSpec) getRdmaResources() []string {
+	//nolint: exhaustive
+	switch r.deviceType {
+	case types.NetDeviceType:
+		return utils.GetRdmaProvider().GetRdmaDevicesForPcidev(r.deviceID)
+	case types.AuxNetDeviceType:
+		return utils.GetRdmaProvider().GetRdmaDevicesForAuxdev(r.deviceID)
+	default:
+		return make([]string, 0)
+	}
+}
+
+func (r *rdmaSpec) GetRdmaDeviceSpec() []*pluginapi.DeviceSpec {
+	rdmaResources := r.getRdmaResources()
 	deviceSpec := make([]*pluginapi.DeviceSpec, 0)
-	isSupportRdma := false
 	if len(rdmaResources) > 0 {
-		isSupportRdma = true
 		for _, res := range rdmaResources {
 			resRdmaDevices := utils.GetRdmaProvider().GetRdmaCharDevices(res)
 			for _, rdmaDevice := range resRdmaDevices {
@@ -45,26 +92,5 @@ func newRdmaSpec(rdmaResources []string) types.RdmaSpec {
 			}
 		}
 	}
-
-	return &rdmaSpec{isSupportRdma: isSupportRdma, deviceSpec: deviceSpec}
-}
-
-// NewRdmaSpec returns the RdmaSpec for PCI address
-func NewRdmaSpec(pciAddr string) types.RdmaSpec {
-	rdmaResources := utils.GetRdmaProvider().GetRdmaDevicesForPcidev(pciAddr)
-	return newRdmaSpec(rdmaResources)
-}
-
-// NewAuxRdmaSpec returns the RdmaSpec for auxiliary device ID
-func NewAuxRdmaSpec(deviceID string) types.RdmaSpec {
-	rdmaResources := utils.GetRdmaProvider().GetRdmaDevicesForAuxdev(deviceID)
-	return newRdmaSpec(rdmaResources)
-}
-
-func (r *rdmaSpec) IsRdma() bool {
-	return r.isSupportRdma
-}
-
-func (r *rdmaSpec) GetRdmaDeviceSpec() []*pluginapi.DeviceSpec {
-	return r.deviceSpec
+	return deviceSpec
 }
diff --git a/pkg/devices/rdma_test.go b/pkg/devices/rdma_test.go
index cd321b6d9..e97500a40 100644
--- a/pkg/devices/rdma_test.go
+++ b/pkg/devices/rdma_test.go
@@ -23,6 +23,7 @@ import (
 	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
 
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/devices"
+	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils"
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils/mocks"
 )
@@ -31,16 +32,58 @@ var _ = Describe("RdmaSpec", func() {
 	Describe("creating new RdmaSpec", func() {
 		t := GinkgoT()
 		Context("successfully", func() {
-			It("without device specs", func() {
+			It("without device specs, without netlink enable_rdma param", func() {
+				mockProvider := &mocks.NetlinkProvider{}
+				mockProvider.On("HasRdmaParam", "pci", "0000:00:00.0").Return(false, nil)
+				utils.SetNetlinkProviderInst(mockProvider)
 				fakeRdmaProvider := mocks.RdmaProvider{}
 				fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0").Return([]string{})
 				utils.SetRdmaProviderInst(&fakeRdmaProvider)
-				spec := devices.NewRdmaSpec("0000:00:00.0")
+				spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0")
 
 				Expect(spec.IsRdma()).To(BeFalse())
 				Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0))
 				fakeRdmaProvider.AssertExpectations(t)
 			})
+			It("without device specs, with netlink enable_rdma param", func() {
+				mockProvider := &mocks.NetlinkProvider{}
+				mockProvider.On("HasRdmaParam", "pci", "0000:00:00.0").Return(true, nil)
+				utils.SetNetlinkProviderInst(mockProvider)
+				fakeRdmaProvider := mocks.RdmaProvider{}
+				fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0").Return([]string{})
+				utils.SetRdmaProviderInst(&fakeRdmaProvider)
+				spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0")
+
+				Expect(spec.IsRdma()).To(BeTrue())
+				Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0))
+				fakeRdmaProvider.AssertExpectations(t)
+			})
+			It("aux without device specs, without netlink enable_rdma param", func() {
+				mockProvider := &mocks.NetlinkProvider{}
+				mockProvider.On("HasRdmaParam", "auxiliary", "mlx5_core.sf.4").Return(false, nil)
+				utils.SetNetlinkProviderInst(mockProvider)
+				fakeRdmaProvider := mocks.RdmaProvider{}
+				fakeRdmaProvider.On("GetRdmaDevicesForAuxdev", "mlx5_core.sf.4").Return([]string{})
+				utils.SetRdmaProviderInst(&fakeRdmaProvider)
+				spec := devices.NewRdmaSpec(types.AuxNetDeviceType, "mlx5_core.sf.4")
+
+				Expect(spec.IsRdma()).To(BeFalse())
+				Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0))
+				fakeRdmaProvider.AssertExpectations(t)
+			})
+			It("aux without device specs, with netlink enable_rdma param", func() {
+				mockProvider := &mocks.NetlinkProvider{}
+				mockProvider.On("HasRdmaParam", "auxiliary", "mlx5_core.sf.4").Return(true, nil)
+				utils.SetNetlinkProviderInst(mockProvider)
+				fakeRdmaProvider := mocks.RdmaProvider{}
+				fakeRdmaProvider.On("GetRdmaDevicesForAuxdev", "mlx5_core.sf.4").Return([]string{})
+				utils.SetRdmaProviderInst(&fakeRdmaProvider)
+				spec := devices.NewRdmaSpec(types.AuxNetDeviceType, "mlx5_core.sf.4")
+
+				Expect(spec.IsRdma()).To(BeTrue())
+				Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0))
+				fakeRdmaProvider.AssertExpectations(t)
+			})
 			It("with device specs", func() {
 				fakeRdmaProvider := mocks.RdmaProvider{}
 				fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0").
@@ -50,7 +93,7 @@ var _ = Describe("RdmaSpec", func() {
 					"/dev/infiniband/uverbs0", "/dev/infiniband/rdma_cm",
 				}).On("GetRdmaCharDevices", "fake_1").Return([]string{"/dev/infiniband/rdma_cm"})
 				utils.SetRdmaProviderInst(&fakeRdmaProvider)
-				spec := devices.NewRdmaSpec("0000:00:00.0")
+				spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0")
 
 				Expect(spec.IsRdma()).To(BeTrue())
 				Expect(spec.GetRdmaDeviceSpec()).To(Equal([]*pluginapi.DeviceSpec{
diff --git a/pkg/factory/factory.go b/pkg/factory/factory.go
index 15a287658..3b3ad938b 100644
--- a/pkg/factory/factory.go
+++ b/pkg/factory/factory.go
@@ -163,15 +163,7 @@ func (rf *resourceFactory) GetResourcePool(rc *types.ResourceConfig, filteredDev
 }
 
 func (rf *resourceFactory) GetRdmaSpec(dt types.DeviceType, deviceID string) types.RdmaSpec {
-	//nolint: exhaustive
-	switch dt {
-	case types.NetDeviceType:
-		return devices.NewRdmaSpec(deviceID)
-	case types.AuxNetDeviceType:
-		return devices.NewAuxRdmaSpec(deviceID)
-	default:
-		return nil
-	}
+	return devices.NewRdmaSpec(dt, deviceID)
 }
 
 func (rf *resourceFactory) GetVdpaDevice(pciAddr string) types.VdpaDevice {
diff --git a/pkg/factory/factory_test.go b/pkg/factory/factory_test.go
index 8d5251627..906c421b7 100644
--- a/pkg/factory/factory_test.go
+++ b/pkg/factory/factory_test.go
@@ -25,10 +25,12 @@ import (
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types"
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types/mocks"
 	"github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils"
+	utilmocks "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils/mocks"
 
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/ginkgo/extensions/table"
 	. "github.com/onsi/gomega"
+	"github.com/stretchr/testify/mock"
 	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
 )
 
@@ -606,6 +608,10 @@ var _ = Describe("Factory", func() {
 	)
 	Describe("getting rdma spec", func() {
 		Context("check c rdma spec", func() {
+			mockProvider := &utilmocks.NetlinkProvider{}
+			mockProvider.On("HasRdmaParam", mock.AnythingOfType("string"),
+				mock.AnythingOfType("string")).Return(false, nil)
+			utils.SetNetlinkProviderInst(mockProvider)
 			f := factory.NewResourceFactory("fake", "fake", true, false)
 			rs1 := f.GetRdmaSpec(types.NetDeviceType, "0000:00:00.1")
 			rs2 := f.GetRdmaSpec(types.AcceleratorType, "0000:00:00.2")
diff --git a/pkg/utils/mocks/NetlinkProvider.go b/pkg/utils/mocks/NetlinkProvider.go
index d9f4c2caf..7c62b1ed7 100644
--- a/pkg/utils/mocks/NetlinkProvider.go
+++ b/pkg/utils/mocks/NetlinkProvider.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v2.43.2. DO NOT EDIT.
+// Code generated by mockery v2.46.3. DO NOT EDIT.
 
 package mocks
 
@@ -43,7 +43,7 @@ func (_m *NetlinkProvider) GetDevLinkDeviceEswitchAttrs(ifName string) (*netlink
 }
 
 // GetDevlinkGetDeviceInfoByNameAsMap provides a mock function with given fields: bus, device
-func (_m *NetlinkProvider) GetDevlinkGetDeviceInfoByNameAsMap(bus string, device string) (map[string]string, error) {
+func (_m *NetlinkProvider) GetDevlinkGetDeviceInfoByNameAsMap(bus, device string) (map[string]string, error) {
 	ret := _m.Called(bus, device)
 
 	if len(ret) == 0 {
@@ -132,6 +132,34 @@ func (_m *NetlinkProvider) GetLinkAttrs(ifName string) (*netlink.LinkAttrs, erro
 	return r0, r1
 }
 
+// HasRdmaParam provides a mock function with given fields: bus, pciAddr
+func (_m *NetlinkProvider) HasRdmaParam(bus, pciAddr string) (bool, error) {
+	ret := _m.Called(bus, pciAddr)
+
+	if len(ret) == 0 {
+		panic("no return value specified for HasRdmaParam")
+	}
+
+	var r0 bool
+	var r1 error
+	if rf, ok := ret.Get(0).(func(string, string) (bool, error)); ok {
+		return rf(bus, pciAddr)
+	}
+	if rf, ok := ret.Get(0).(func(string, string) bool); ok {
+		r0 = rf(bus, pciAddr)
+	} else {
+		r0 = ret.Get(0).(bool)
+	}
+
+	if rf, ok := ret.Get(1).(func(string, string) error); ok {
+		r1 = rf(bus, pciAddr)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
 // NewNetlinkProvider creates a new instance of NetlinkProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
 // The first argument is typically a *testing.T value.
 func NewNetlinkProvider(t interface {
diff --git a/pkg/utils/netlink_provider.go b/pkg/utils/netlink_provider.go
index a87074d7b..70a0c55ec 100644
--- a/pkg/utils/netlink_provider.go
+++ b/pkg/utils/netlink_provider.go
@@ -31,6 +31,8 @@ type NetlinkProvider interface {
 	GetIPv4RouteList(ifName string) ([]nl.Route, error)
 	// DevlinkGetDeviceInfoByNameAsMap returns devlink info for selected device as a map
 	GetDevlinkGetDeviceInfoByNameAsMap(bus, device string) (map[string]string, error)
+	// HasRdmaParam returns true if device has "enable_rdma" param
+	HasRdmaParam(bus, pciAddr string) (bool, error)
 }
 
 type defaultNetlinkProvider struct {
@@ -48,6 +50,26 @@ func GetNetlinkProvider() NetlinkProvider {
 	return netlinkProvider
 }
 
+// HasRdmaParam returns true if device has "enable_rdma" param
+// equivalent to "devlink dev param show pci/0000:d8:01.1 name enable_rdma"
+// or "devlink dev param show auxiliary/mlx5_core.sf.4 name enable_rdma"
+func (defaultNetlinkProvider) HasRdmaParam(bus, deviceID string) (bool, error) {
+	param, err := nl.DevlinkGetDeviceParamByName(bus, deviceID, "enable_rdma")
+	if err != nil {
+		return false, fmt.Errorf("error getting enable_rdma attribute for device %s on bus %s %v",
+			deviceID, bus, err)
+	}
+	if len(param.Values) == 0 || param.Values[0].Data == nil {
+		return false, nil
+	}
+	var boolValue bool
+	boolValue, ok := param.Values[0].Data.(bool)
+	if !ok {
+		return false, fmt.Errorf("value is not a bool")
+	}
+	return boolValue, nil
+}
+
 // GetLinkAttrs returns a net device's link attributes.
 func (defaultNetlinkProvider) GetLinkAttrs(ifName string) (*nl.LinkAttrs, error) {
 	link, err := nl.LinkByName(ifName)
diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go
index 262525035..4d5b68609 100644
--- a/pkg/utils/utils.go
+++ b/pkg/utils/utils.go
@@ -474,6 +474,17 @@ func GetPfEswitchMode(pciAddr string) (string, error) {
 	return devLinkDeviceAttrs.Mode, nil
 }
 
+// HasRdmaParam returns true if deviceID has "enable_rdma" param
+// for example: pci 0000:d8:01.1
+// or auxiliary mlx5_core.sf.4
+func HasRdmaParam(bus, deviceID string) (bool, error) {
+	rdma, err := GetNetlinkProvider().HasRdmaParam(bus, deviceID)
+	if err != nil {
+		return false, err
+	}
+	return rdma, nil
+}
+
 // HasDefaultRoute returns true if PCI network device is default route interface
 func HasDefaultRoute(pciAddr string) (bool, error) {
 	// Get net interface name