diff --git a/Makefile b/Makefile index 8b4064819..cfe7b2d85 100644 --- a/Makefile +++ b/Makefile @@ -68,6 +68,9 @@ uninstall: lint: go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2 $(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --deadline=5m --timeout=5m $(module)/... || exit;) +lint-fix: + go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2 + $(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --fix --deadline=5m --timeout=5m $(module)/... || exit;) install-mockgen: go install go.uber.org/mock/mockgen@v0.4.0 mockgen: install-mockgen diff --git a/core/pkg/service/iservice.go b/core/pkg/service/iservice.go index 20151ac51..a373778b6 100644 --- a/core/pkg/service/iservice.go +++ b/core/pkg/service/iservice.go @@ -32,6 +32,7 @@ type Configuration struct { SocketPath string CORS []string Options []connect.HandlerOption + ContextValues map[string]any } /* diff --git a/docs/reference/flag-definitions.md b/docs/reference/flag-definitions.md index 5b7f492f0..7d32f39b8 100644 --- a/docs/reference/flag-definitions.md +++ b/docs/reference/flag-definitions.md @@ -184,6 +184,9 @@ For example, when accessing flagd via HTTP, the POST body may look like this: The evaluation context can be accessed in targeting rules using the `var` operation followed by the evaluation context property name. +The evaluation context can be appended by arbitrary key value pairs +via the `-X` command line flag. + | Description | Example | | -------------------------------------------------------------- | ---------------------------------------------------- | | Retrieve property from the evaluation context | `#!json { "var": "email" }` | diff --git a/docs/reference/flagd-cli/flagd_start.md b/docs/reference/flagd-cli/flagd_start.md index a9d33d3bb..c8f784db3 100644 --- a/docs/reference/flagd-cli/flagd_start.md +++ b/docs/reference/flagd-cli/flagd_start.md @@ -11,6 +11,7 @@ flagd start [flags] ### Options ``` + -X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default []) -C, --cors-origin strings CORS allowed origins, * will allow all origins -h, --help help for start -z, --log-format string Set the logging format, e.g. console or json (default "console") diff --git a/flagd/cmd/start.go b/flagd/cmd/start.go index 2e952549a..7cc16141a 100644 --- a/flagd/cmd/start.go +++ b/flagd/cmd/start.go @@ -34,11 +34,11 @@ const ( sourcesFlagName = "sources" syncPortFlagName = "sync-port" uriFlagName = "uri" + contextValueFlagName = "context-value" ) func init() { flags := startCmd.Flags() - // allows environment variables to use _ instead of - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT @@ -78,6 +78,8 @@ func init() { flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector") flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+ "from disk") + flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+ + "to the flag evaluation context") _ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName)) _ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName)) @@ -95,6 +97,7 @@ func init() { _ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName)) _ = viper.BindPFlag(syncPortFlagName, flags.Lookup(syncPortFlagName)) _ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName)) + _ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName)) } // startCmd represents the start command @@ -139,6 +142,11 @@ var startCmd = &cobra.Command{ } syncProviders = append(syncProviders, syncProvidersFromConfig...) + contextValuesToMap := make(map[string]any) + for k, v := range viper.GetStringMapString(contextValueFlagName) { + contextValuesToMap[k] = v + } + // Build Runtime ----------------------------------------------------------- rt, err := runtime.FromConfig(logger, Version, runtime.Config{ CORS: viper.GetStringSlice(corsFlagName), @@ -156,6 +164,7 @@ var startCmd = &cobra.Command{ ServiceSocketPath: viper.GetString(socketPathFlagName), SyncServicePort: viper.GetUint16(syncPortFlagName), SyncProviders: syncProviders, + ContextValues: contextValuesToMap, }) if err != nil { rtLogger.Fatal(err.Error()) diff --git a/flagd/pkg/runtime/from_config.go b/flagd/pkg/runtime/from_config.go index 010a6023e..8bdc6a039 100644 --- a/flagd/pkg/runtime/from_config.go +++ b/flagd/pkg/runtime/from_config.go @@ -40,6 +40,8 @@ type Config struct { SyncProviders []sync.SourceConfig CORS []string + + ContextValues map[string]any } // FromConfig builds a runtime from startup configurations @@ -101,17 +103,20 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime, ofrepService, err := ofrep.NewOfrepService(jsonEvaluator, config.CORS, ofrep.SvcConfiguration{ Logger: logger.WithFields(zap.String("component", "OFREPService")), Port: config.OfrepServicePort, - }) + }, + config.ContextValues, + ) if err != nil { return nil, fmt.Errorf("error creating ofrep service") } // flag sync service flagSyncService, err := flagsync.NewSyncService(flagsync.SvcConfigurations{ - Logger: logger.WithFields(zap.String("component", "FlagSyncService")), - Port: config.SyncServicePort, - Sources: sources, - Store: s, + Logger: logger.WithFields(zap.String("component", "FlagSyncService")), + Port: config.SyncServicePort, + Sources: sources, + Store: s, + ContextValues: config.ContextValues, }) if err != nil { return nil, fmt.Errorf("error creating sync service: %w", err) @@ -145,6 +150,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime, SocketPath: config.ServiceSocketPath, CORS: config.CORS, Options: options, + ContextValues: config.ContextValues, }, SyncImpl: iSyncs, }, nil diff --git a/flagd/pkg/service/flag-evaluation/connect_service.go b/flagd/pkg/service/flag-evaluation/connect_service.go index 799d2546c..f7c303b74 100644 --- a/flagd/pkg/service/flag-evaluation/connect_service.go +++ b/flagd/pkg/service/flag-evaluation/connect_service.go @@ -154,6 +154,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene s.eval, s.eventingConfiguration, s.metrics, + svcConf.ContextValues, ) marshalOpts := WithJSON( @@ -170,6 +171,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene s.eval, s.eventingConfiguration, s.metrics, + svcConf.ContextValues, ) _, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator.go b/flagd/pkg/service/flag-evaluation/flag_evaluator.go index 29b4c229e..7825e2a5a 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator.go @@ -32,11 +32,16 @@ type OldFlagEvaluationService struct { metrics telemetry.IMetricsRecorder eventingConfiguration IEvents flagEvalTracer trace.Tracer + contextValues map[string]any } // NewOldFlagEvaluationService creates a OldFlagEvaluationService with provided parameters -func NewOldFlagEvaluationService(log *logger.Logger, - eval evaluator.IEvaluator, eventingCfg IEvents, metricsRecorder telemetry.IMetricsRecorder, +func NewOldFlagEvaluationService( + log *logger.Logger, + eval evaluator.IEvaluator, + eventingCfg IEvents, + metricsRecorder telemetry.IMetricsRecorder, + contextValues map[string]any, ) *OldFlagEvaluationService { svc := &OldFlagEvaluationService{ logger: log, @@ -44,6 +49,7 @@ func NewOldFlagEvaluationService(log *logger.Logger, metrics: &telemetry.NoopMetricsRecorder{}, eventingConfiguration: eventingCfg, flagEvalTracer: otel.Tracer("flagEvaluationService"), + contextValues: contextValues, } if metricsRecorder != nil { @@ -65,12 +71,8 @@ func (s *OldFlagEvaluationService) ResolveAll( res := &schemaV1.ResolveAllResponse{ Flags: make(map[string]*schemaV1.AnyFlag), } - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx) + values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues)) if err != nil { s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err)) return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID) @@ -172,6 +174,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{}) + err := resolve[bool]( sCtx, s.logger, @@ -180,6 +183,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean( req.Msg.GetContext(), &booleanResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -206,6 +210,7 @@ func (s *OldFlagEvaluationService) ResolveString( req.Msg.GetContext(), &stringResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -232,6 +237,7 @@ func (s *OldFlagEvaluationService) ResolveInt( req.Msg.GetContext(), &intResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -258,6 +264,7 @@ func (s *OldFlagEvaluationService) ResolveFloat( req.Msg.GetContext(), &floatResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -284,6 +291,7 @@ func (s *OldFlagEvaluationService) ResolveObject( req.Msg.GetContext(), &objectResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -293,21 +301,36 @@ func (s *OldFlagEvaluationService) ResolveObject( return res, err } +// mergeContexts combines values from the request context with the values from the config --context-values flag. +// Request context values have a higher priority. +func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any { + merged := make(map[string]any) + for k, v := range reqCtx { + merged[k] = v + } + for k, v := range configFlagsCtx { + merged[k] = v + } + return merged +} + // resolve is a generic flag resolver func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string, evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder, + configContextValues map[string]any, ) error { reqID := xid.New().String() defer logger.ClearFields(reqID) + mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues) logger.WriteFields( reqID, zap.String("flag-key", flagKey), - zap.Strings("context-keys", formatContextKeys(evaluationContext)), + zap.Strings("context-keys", formatContextKeys(mergedContext)), ) var evalErrFormatted error - result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext.AsMap()) + result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext) if evalErr != nil { logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr)) reason = model.ErrorReason @@ -329,9 +352,9 @@ func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver return evalErrFormatted } -func formatContextKeys(context *structpb.Struct) []string { +func formatContextKeys(context map[string]any) []string { res := []string{} - for k := range context.AsMap() { + for k := range context { res = append(res, k) } return res diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go index d85a4ade7..127875956 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go @@ -128,6 +128,7 @@ func TestConnectService_ResolveAll(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req)) if err != nil && !errors.Is(err, tt.wantErr) { @@ -235,6 +236,7 @@ func TestFlag_Evaluation_ResolveBoolean(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -290,6 +292,7 @@ func BenchmarkFlag_Evaluation_ResolveBoolean(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -388,6 +391,7 @@ func TestFlag_Evaluation_ResolveString(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -443,6 +447,7 @@ func BenchmarkFlag_Evaluation_ResolveString(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -540,6 +545,7 @@ func TestFlag_Evaluation_ResolveFloat(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -595,6 +601,7 @@ func BenchmarkFlag_Evaluation_ResolveFloat(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -692,6 +699,7 @@ func TestFlag_Evaluation_ResolveInt(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -747,6 +755,7 @@ func BenchmarkFlag_Evaluation_ResolveInt(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -847,6 +856,7 @@ func TestFlag_Evaluation_ResolveObject(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) outParsed, err := structpb.NewStruct(tt.evalFields.result) @@ -910,6 +920,7 @@ func BenchmarkFlag_Evaluation_ResolveObject(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) if name != "eval returns error" { outParsed, err := structpb.NewStruct(tt.evalFields.result) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go index 277976f91..ebc8685b8 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go @@ -25,6 +25,7 @@ type FlagEvaluationService struct { metrics telemetry.IMetricsRecorder eventingConfiguration IEvents flagEvalTracer trace.Tracer + contextValues map[string]any } // NewFlagEvaluationService creates a FlagEvaluationService with provided parameters @@ -32,6 +33,7 @@ func NewFlagEvaluationService(log *logger.Logger, eval evaluator.IEvaluator, eventingCfg IEvents, metricsRecorder telemetry.IMetricsRecorder, + contextValues map[string]any, ) *FlagEvaluationService { svc := &FlagEvaluationService{ logger: log, @@ -39,6 +41,7 @@ func NewFlagEvaluationService(log *logger.Logger, metrics: &telemetry.NoopMetricsRecorder{}, eventingConfiguration: eventingCfg, flagEvalTracer: otel.Tracer("flagd.evaluation.v1"), + contextValues: contextValues, } if metricsRecorder != nil { @@ -63,12 +66,7 @@ func (s *FlagEvaluationService) ResolveAll( Flags: make(map[string]*evalV1.AnyFlag), } - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - - values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx) + values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues)) if err != nil { s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err)) return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID) @@ -167,8 +165,9 @@ func (s *FlagEvaluationService) ResolveBoolean( ) (*connect.Response[evalV1.ResolveBooleanResponse], error) { sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() + res := connect.NewResponse(&evalV1.ResolveBooleanResponse{}) - err := resolve[bool]( + err := resolve( sCtx, s.logger, s.eval.ResolveBooleanValue, @@ -176,6 +175,7 @@ func (s *FlagEvaluationService) ResolveBoolean( req.Msg.GetContext(), &booleanResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -193,7 +193,7 @@ func (s *FlagEvaluationService) ResolveString( defer span.End() res := connect.NewResponse(&evalV1.ResolveStringResponse{}) - err := resolve[string]( + err := resolve( sCtx, s.logger, s.eval.ResolveStringValue, @@ -201,6 +201,7 @@ func (s *FlagEvaluationService) ResolveString( req.Msg.GetContext(), &stringResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -218,7 +219,7 @@ func (s *FlagEvaluationService) ResolveInt( defer span.End() res := connect.NewResponse(&evalV1.ResolveIntResponse{}) - err := resolve[int64]( + err := resolve( sCtx, s.logger, s.eval.ResolveIntValue, @@ -226,6 +227,7 @@ func (s *FlagEvaluationService) ResolveInt( req.Msg.GetContext(), &intResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -243,7 +245,7 @@ func (s *FlagEvaluationService) ResolveFloat( defer span.End() res := connect.NewResponse(&evalV1.ResolveFloatResponse{}) - err := resolve[float64]( + err := resolve( sCtx, s.logger, s.eval.ResolveFloatValue, @@ -251,6 +253,7 @@ func (s *FlagEvaluationService) ResolveFloat( req.Msg.GetContext(), &floatResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -268,7 +271,7 @@ func (s *FlagEvaluationService) ResolveObject( defer span.End() res := connect.NewResponse(&evalV1.ResolveObjectResponse{}) - err := resolve[map[string]any]( + err := resolve( sCtx, s.logger, s.eval.ResolveObjectValue, @@ -276,6 +279,7 @@ func (s *FlagEvaluationService) ResolveObject( req.Msg.GetContext(), &objectResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go index e77b67fa9..83eedce14 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "reflect" "testing" evalV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v1" @@ -93,7 +94,7 @@ func TestConnectServiceV2_ResolveAll(t *testing.T) { ).AnyTimes() metrics, exp := getMetricReader() - s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics) + s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil) // when got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req)) @@ -208,6 +209,7 @@ func TestFlag_EvaluationV2_ResolveBoolean(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -263,6 +265,7 @@ func BenchmarkFlag_EvaluationV2_ResolveBoolean(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -361,6 +364,7 @@ func TestFlag_EvaluationV2_ResolveString(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -416,6 +420,7 @@ func BenchmarkFlag_EvaluationV2_ResolveString(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -513,6 +518,7 @@ func TestFlag_EvaluationV2_ResolveFloat(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -568,6 +574,7 @@ func BenchmarkFlag_EvaluationV2_ResolveFloat(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -665,6 +672,7 @@ func TestFlag_EvaluationV2_ResolveInt(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) if (err != nil) && !errors.Is(err, tt.wantErr) { @@ -720,6 +728,7 @@ func BenchmarkFlag_EvaluationV2_ResolveInt(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) b.Run(name, func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -820,6 +829,7 @@ func TestFlag_EvaluationV2_ResolveObject(t *testing.T) { eval, &eventingConfiguration{}, metrics, + nil, ) outParsed, err := structpb.NewStruct(tt.evalFields.result) @@ -883,6 +893,7 @@ func BenchmarkFlag_EvaluationV2_ResolveObject(b *testing.B) { eval, &eventingConfiguration{}, metrics, + nil, ) if name != "eval returns error" { outParsed, err := structpb.NewStruct(tt.evalFields.result) @@ -955,3 +966,35 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) { } } } + +func Test_mergeContexts(t *testing.T) { + type args struct { + clientContext, configContext map[string]any + } + + tests := []struct { + name string + args args + want map[string]any + }{ + { + name: "merge contexts", + args: args{ + clientContext: map[string]any{"k1": "v1", "k2": "v2"}, + configContext: map[string]any{"k2": "v22", "k3": "v3"}, + }, + // static context should "win" + want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeContexts(tt.args.clientContext, tt.args.configContext) + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want) + } + }) + } +} diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler.go b/flagd/pkg/service/flag-evaluation/ofrep/handler.go index f2bee4907..bc97390ba 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler.go @@ -20,14 +20,16 @@ const ( ) type handler struct { - Logger *logger.Logger - evaluator evaluator.IEvaluator + Logger *logger.Logger + evaluator evaluator.IEvaluator + contextValues map[string]any } -func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator) http.Handler { +func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler { h := handler{ - logger, - evaluator, + Logger: logger, + evaluator: evaluator, + contextValues: contextValues, } router := mux.NewRouter() @@ -56,7 +58,7 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) { return } - context := flagdContext(h.Logger, requestID, request) + context := flagdContext(h.Logger, requestID, request, h.contextValues) evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context) if evaluation.Error != nil { status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation) @@ -76,7 +78,7 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) { return } - context := flagdContext(h.Logger, requestID, request) + context := flagdContext(h.Logger, requestID, request, h.contextValues) evaluations, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context) if err != nil { h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err)) @@ -117,13 +119,21 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) { return request, nil } -func flagdContext(log *logger.Logger, requestID string, request ofrep.Request) map[string]any { - context := map[string]any{} +func flagdContext( + log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any, +) map[string]any { + context := make(map[string]any) if res, ok := request.Context.(map[string]any); ok { - context = res + for k, v := range res { + context[k] = v + } } else { log.WarnWithID(requestID, "provided context does not comply with flagd, continuing ignoring the context") } + for k, v := range staticContextValues { + context[k] = v + } + return context } diff --git a/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service.go b/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service.go index 2bff464ca..68169307c 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service.go @@ -29,12 +29,14 @@ type Service struct { server *http.Server } -func NewOfrepService(evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration) (*Service, error) { +func NewOfrepService( + evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any, +) (*Service, error) { corsMW := cors.New(cors.Options{ AllowedOrigins: origins, AllowedMethods: []string{http.MethodPost}, }) - h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator)) + h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues)) server := http.Server{ Addr: fmt.Sprintf(":%d", cfg.Port), diff --git a/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service_test.go b/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service_test.go index 37479e25b..0afcce00d 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service_test.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/ofrep_service_test.go @@ -27,7 +27,7 @@ func Test_OfrepServiceStartStop(t *testing.T) { Port: uint16(port), } - service, err := NewOfrepService(eval, []string{"*"}, cfg) + service, err := NewOfrepService(eval, []string{"*"}, cfg, nil) if err != nil { t.Fatalf("error creating the ofrep service: %v", err) } diff --git a/flagd/pkg/service/flag-sync/handler.go b/flagd/pkg/service/flag-sync/handler.go index ba7afcca6..d971ca6ca 100644 --- a/flagd/pkg/service/flag-sync/handler.go +++ b/flagd/pkg/service/flag-sync/handler.go @@ -5,15 +5,16 @@ import ( "fmt" "buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc" - "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1" + syncv1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1" "github.com/open-feature/flagd/core/pkg/logger" "google.golang.org/protobuf/types/known/structpb" ) // syncHandler implements the sync contract type syncHandler struct { - mux *Multiplexer - log *logger.Logger + mux *Multiplexer + log *logger.Logger + contextValues map[string]any } func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error { @@ -59,9 +60,15 @@ func (s syncHandler) FetchAllFlags(_ context.Context, req *syncv1.FetchAllFlagsR func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) ( *syncv1.GetMetadataResponse, error, ) { - metadata, err := structpb.NewStruct(map[string]interface{}{ - "sources": s.mux.SourcesAsMetadata(), - }) + metadataSrc := make(map[string]any) + for k, v := range s.contextValues { + metadataSrc[k] = v + } + if sources := s.mux.SourcesAsMetadata(); sources != "" { + metadataSrc["sources"] = sources + } + + metadata, err := structpb.NewStruct(metadataSrc) if err != nil { s.log.Warn(fmt.Sprintf("error from struct creation: %v", err)) return nil, fmt.Errorf("error constructing metadata response") diff --git a/flagd/pkg/service/flag-sync/sync_service.go b/flagd/pkg/service/flag-sync/sync_service.go index 05cc195e9..b1baefa7d 100644 --- a/flagd/pkg/service/flag-sync/sync_service.go +++ b/flagd/pkg/service/flag-sync/sync_service.go @@ -23,10 +23,11 @@ type ISyncService interface { } type SvcConfigurations struct { - Logger *logger.Logger - Port uint16 - Sources []string - Store *store.Flags + Logger *logger.Logger + Port uint16 + Sources []string + Store *store.Flags + ContextValues map[string]any } type Service struct { @@ -47,8 +48,9 @@ func NewSyncService(cfg SvcConfigurations) (*Service, error) { server := grpc.NewServer() syncv1grpc.RegisterFlagSyncServiceServer(server, &syncHandler{ - mux: mux, - log: l, + mux: mux, + log: l, + contextValues: cfg.ContextValues, }) l.Info(fmt.Sprintf("starting flag sync service on port %d", cfg.Port))