Skip to content

Commit

Permalink
fix: internal provider comparison causing race conditions in tests (#312
Browse files Browse the repository at this point in the history
)

fix: internal provider comparison causing race conditions in tests

Signed-off-by: Bernd Warmuth <[email protected]>
Co-authored-by: Todd Baert <[email protected]>
  • Loading branch information
warber and toddbaert authored Jan 21, 2025
1 parent 890bfd0 commit 440072f
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 33 deletions.
41 changes: 8 additions & 33 deletions openfeature/event_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openfeature

import (
"fmt"
"reflect"
"sync"
"time"

Expand Down Expand Up @@ -68,13 +67,6 @@ type eventPayload struct {
handler FeatureProvider
}

// providerReference is a helper struct to store FeatureProvider with EventHandler capability along with their
// shutdown semaphore
type providerReference struct {
featureProvider FeatureProvider
shutdownSemaphore chan interface{}
}

// AddHandler adds an API(global) level handler
func (e *eventExecutor) AddHandler(t EventType, c EventCallback) {
e.mu.Lock()
Expand Down Expand Up @@ -217,14 +209,7 @@ func (e *eventExecutor) registerDefaultProvider(provider FeatureProvider) error
e.mu.Lock()
defer e.mu.Unlock()

// register shutdown semaphore for new default provider
sem := make(chan interface{})

newProvider := providerReference{
featureProvider: provider,
shutdownSemaphore: sem,
}

newProvider := newProviderRef(provider)
oldProvider := e.defaultProviderReference
e.defaultProviderReference = newProvider

Expand All @@ -235,14 +220,7 @@ func (e *eventExecutor) registerDefaultProvider(provider FeatureProvider) error
func (e *eventExecutor) registerNamedEventingProvider(associatedClient string, provider FeatureProvider) error {
e.mu.Lock()
defer e.mu.Unlock()

// register shutdown semaphore for new named provider
sem := make(chan interface{})

newProvider := providerReference{
featureProvider: provider,
shutdownSemaphore: sem,
}
newProvider := newProviderRef(provider)

oldProvider := e.namedProviderReference[associatedClient]
e.namedProviderReference[associatedClient] = newProvider
Expand Down Expand Up @@ -288,7 +266,7 @@ func (e *eventExecutor) startListeningAndShutdownOld(newProvider providerReferen

// drop from active references
for i, r := range e.activeSubscriptions {
if reflect.DeepEqual(oldReference.featureProvider, r.featureProvider) {
if oldReference.equals(r) {
e.activeSubscriptions = append(e.activeSubscriptions[:i], e.activeSubscriptions[i+1:]...)
}
}
Expand Down Expand Up @@ -332,8 +310,7 @@ func (e *eventExecutor) triggerEvent(event Event, handler FeatureProvider) {

// then run client handlers
for domain, reference := range e.namedProviderReference {
if !reflect.DeepEqual(reference.featureProvider, handler) {
// unassociated client, continue to next
if !reference.equals(newProviderRef(handler)) {
continue
}

Expand All @@ -343,7 +320,7 @@ func (e *eventExecutor) triggerEvent(event Event, handler FeatureProvider) {
}
}

if !reflect.DeepEqual(e.defaultProviderReference.featureProvider, handler) {
if !e.defaultProviderReference.equals(newProviderRef(handler)) {
return
}

Expand Down Expand Up @@ -386,25 +363,23 @@ func (e *eventExecutor) executeHandler(f func(details EventDetails), event Event
// isRunning is a helper till we bump to the latest go version with slices.contains support
func isRunning(provider providerReference, activeProviders []providerReference) bool {
for _, activeProvider := range activeProviders {
if reflect.DeepEqual(activeProvider.featureProvider, provider.featureProvider) {
if activeProvider.equals(provider) {
return true
}
}

return false
}

// isRunning is a helper to check if given provider is already in use
func isBound(provider providerReference, defaultProvider providerReference, namedProviders []providerReference) bool {
if reflect.DeepEqual(provider.featureProvider, defaultProvider.featureProvider) {
if provider.equals(defaultProvider) {
return true
}

for _, namedProvider := range namedProviders {
if reflect.DeepEqual(provider.featureProvider, namedProvider.featureProvider) {
if provider.equals(namedProvider) {
return true
}
}

return false
}
29 changes: 29 additions & 0 deletions openfeature/reference.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package openfeature

import (
"reflect"
)

// newProviderRef creates a new providerReference instance that wraps around a FeatureProvider implementation
func newProviderRef(provider FeatureProvider) providerReference {
return providerReference{
featureProvider: provider,
kind: reflect.TypeOf(provider).Kind(),
shutdownSemaphore: make(chan interface{}),
}
}

// providerReference is a helper struct to store FeatureProvider along with their
// shutdown semaphore
type providerReference struct {
featureProvider FeatureProvider
kind reflect.Kind
shutdownSemaphore chan interface{}
}

func (pr providerReference) equals(other providerReference) bool {
if pr.kind == reflect.Ptr && other.kind == reflect.Ptr {
return pr.featureProvider == other.featureProvider
}
return reflect.DeepEqual(pr.featureProvider, other.featureProvider)
}
69 changes: 69 additions & 0 deletions openfeature/reference_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package openfeature

import (
"testing"
)

func TestProviderReferenceEquals(t *testing.T) {

type myProvider struct {
NoopProvider
field string
}

p1 := myProvider{}
p2 := myProvider{}

tests := []struct {
name string
pr1 providerReference
pr2 providerReference
expected bool
}{

{
name: "both pointers, different instances",
pr1: newProviderRef(&p1),
pr2: newProviderRef(&p2),
expected: false,
},
{
name: "both pointers, same instance",
pr1: newProviderRef(&p1),
pr2: newProviderRef(&p1),
expected: true,
},
{
name: "different pointers, different instance",
pr1: newProviderRef(p1),
pr2: newProviderRef(&p1),
expected: false,
},
{
name: "no pointers, same instance",
pr1: newProviderRef(p1),
pr2: newProviderRef(p1),
expected: true,
},
{
name: "no pointers, different equal instances",
pr1: newProviderRef(myProvider{field: "A"}),
pr2: newProviderRef(myProvider{field: "A"}),
expected: true,
},
{
name: "no pointers, different not equal instances",
pr1: newProviderRef(myProvider{field: "A"}),
pr2: newProviderRef(myProvider{field: "B"}),
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.pr1.equals(tt.pr2); got != tt.expected {
t.Errorf("providerReference.equals() = %v, want %v", got, tt.expected)
}
})
}
}

0 comments on commit 440072f

Please sign in to comment.