diff --git a/internal/entitlement/httpdriver/entitlement.go b/internal/entitlement/httpdriver/entitlement.go index 115cc4846..521d92b77 100644 --- a/internal/entitlement/httpdriver/entitlement.go +++ b/internal/entitlement/httpdriver/entitlement.go @@ -8,9 +8,6 @@ import ( "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/internal/entitlement" - booleanentitlement "github.com/openmeterio/openmeter/internal/entitlement/boolean" - meteredentitlement "github.com/openmeterio/openmeter/internal/entitlement/metered" - staticentitlement "github.com/openmeterio/openmeter/internal/entitlement/static" "github.com/openmeterio/openmeter/internal/namespace/namespacedriver" "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/convert" @@ -181,37 +178,11 @@ func (h *entitlementHandler) GetEntitlementValue() GetEntitlementValueHandler { }, nil }, func(ctx context.Context, request GetEntitlementValueHandlerRequest) (api.EntitlementValue, error) { - entitlement, err := h.connector.GetEntitlementValue(ctx, request.Namespace, request.SubjectKey, request.EntitlementIdOrFeatureKey, request.At) + entitlementValue, err := h.connector.GetEntitlementValue(ctx, request.Namespace, request.SubjectKey, request.EntitlementIdOrFeatureKey, request.At) if err != nil { return api.EntitlementValue{}, err } - - switch ent := entitlement.(type) { - case *meteredentitlement.MeteredEntitlementValue: - return api.EntitlementValue{ - HasAccess: convert.ToPointer(ent.HasAccess()), - Balance: &ent.Balance, - Usage: &ent.UsageInPeriod, - Overage: &ent.Overage, - }, nil - case *staticentitlement.StaticEntitlementValue: - var config *string - if len(ent.Config) > 0 { - config = convert.ToPointer(string(ent.Config)) - } - - return api.EntitlementValue{ - HasAccess: convert.ToPointer(ent.HasAccess()), - Config: config, - }, nil - case *booleanentitlement.BooleanEntitlementValue: - return api.EntitlementValue{ - HasAccess: convert.ToPointer(ent.HasAccess()), - }, nil - default: - return api.EntitlementValue{}, errors.New("unknown entitlement type") - } - + return MapEntitlementValueToAPI(entitlementValue) }, commonhttp.JSONResponseEncoder[api.EntitlementValue], httptransport.AppendOptions( diff --git a/internal/entitlement/httpdriver/parser.go b/internal/entitlement/httpdriver/parser.go index 84d58712c..86dca321f 100644 --- a/internal/entitlement/httpdriver/parser.go +++ b/internal/entitlement/httpdriver/parser.go @@ -1,6 +1,7 @@ package httpdriver import ( + "errors" "fmt" "github.com/openmeterio/openmeter/api" @@ -133,6 +134,34 @@ func (p parser) ToAPIGeneric(e *entitlement.Entitlement) (*api.Entitlement, erro } } +func MapEntitlementValueToAPI(entitlementValue entitlement.EntitlementValue) (api.EntitlementValue, error) { + switch ent := entitlementValue.(type) { + case *meteredentitlement.MeteredEntitlementValue: + return api.EntitlementValue{ + HasAccess: convert.ToPointer(ent.HasAccess()), + Balance: &ent.Balance, + Usage: &ent.UsageInPeriod, + Overage: &ent.Overage, + }, nil + case *staticentitlement.StaticEntitlementValue: + var config *string + if len(ent.Config) > 0 { + config = convert.ToPointer(string(ent.Config)) + } + + return api.EntitlementValue{ + HasAccess: convert.ToPointer(ent.HasAccess()), + Config: config, + }, nil + case *booleanentitlement.BooleanEntitlementValue: + return api.EntitlementValue{ + HasAccess: convert.ToPointer(ent.HasAccess()), + }, nil + default: + return api.EntitlementValue{}, errors.New("unknown entitlement type") + } +} + func mapUsagePeriod(u *entitlement.UsagePeriod) *api.RecurringPeriod { if u == nil { return nil diff --git a/openmeter/entitlement/httpdriver/handlers.go b/openmeter/entitlement/httpdriver/handlers.go index bcc939718..8130b9f4a 100644 --- a/openmeter/entitlement/httpdriver/handlers.go +++ b/openmeter/entitlement/httpdriver/handlers.go @@ -1,6 +1,7 @@ package httpdriver import ( + "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/internal/entitlement/httpdriver" "github.com/openmeterio/openmeter/openmeter/entitlement" meteredentitlement "github.com/openmeterio/openmeter/openmeter/entitlement/metered" @@ -38,3 +39,7 @@ func NewMeteredEntitlementHandler( ) MeteredEntitlementHandler { return httpdriver.NewMeteredEntitlementHandler(entitlementConnector, meteredEntitlementConnector, namespaceDecoder, options...) } + +func MapEntitlementValueToAPI(entitlementValue entitlement.EntitlementValue) (api.EntitlementValue, error) { + return httpdriver.MapEntitlementValueToAPI(entitlementValue) +}