Skip to content

Commit

Permalink
fix: billing updated-at schemantics (#2021)
Browse files Browse the repository at this point in the history
  • Loading branch information
turip authored Jan 6, 2025
1 parent 2cda3e2 commit 9313845
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 105 deletions.
1 change: 1 addition & 0 deletions openmeter/billing/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type ProfileAdapter interface {
GetDefaultProfile(ctx context.Context, input GetDefaultProfileInput) (*BaseProfile, error)
DeleteProfile(ctx context.Context, input DeleteProfileInput) error
UpdateProfile(ctx context.Context, input UpdateProfileAdapterInput) (*BaseProfile, error)
UnsetDefaultProfile(ctx context.Context, input UnsetDefaultProfileInput) error
}

type CustomerOverrideAdapter interface {
Expand Down
18 changes: 14 additions & 4 deletions openmeter/billing/adapter/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,7 @@ func (a *adapter) DeleteProfile(ctx context.Context, input billing.DeleteProfile

return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error {
profile, err := tx.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: input.Namespace,
ID: input.ID,
},
Profile: input,
})
if err != nil {
return err
Expand Down Expand Up @@ -254,6 +251,19 @@ func (a *adapter) UpdateProfile(ctx context.Context, input billing.UpdateProfile
})
}

func (a *adapter) UnsetDefaultProfile(ctx context.Context, input billing.UnsetDefaultProfileInput) error {
if err := input.Validate(); err != nil {
return err
}

return entutils.TransactingRepoWithNoValue(ctx, a, func(ctx context.Context, tx *adapter) error {
return tx.db.BillingProfile.Update().
Where(billingprofile.Namespace(input.Namespace)).
SetDefault(false).
Exec(ctx)
})
}

func (a *adapter) updateWorkflowConfig(ctx context.Context, ns string, id string, input billing.WorkflowConfig) (*db.BillingWorkflowConfig, error) {
return a.db.BillingWorkflowConfig.UpdateOneID(id).
Where(billingworkflowconfig.Namespace(ns)).
Expand Down
6 changes: 1 addition & 5 deletions openmeter/billing/customeroverride.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,7 @@ func (i UpdateCustomerOverrideAdapterInput) Validate() error {
return nil
}

type HasCustomerOverrideReferencingProfileAdapterInput genericNamespaceID

func (i HasCustomerOverrideReferencingProfileAdapterInput) Validate() error {
return genericNamespaceID(i).Validate()
}
type HasCustomerOverrideReferencingProfileAdapterInput = ProfileID

type (
UpsertCustomerOverrideAdapterInput = customerentity.CustomerID
Expand Down
2 changes: 1 addition & 1 deletion openmeter/billing/httpdriver/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (h *handler) GetProfile() GetProfileHandler {
}

return GetProfileRequest{
Profile: models.NamespacedID{
Profile: billing.ProfileID{
Namespace: ns,
ID: params.ID,
},
Expand Down
48 changes: 21 additions & 27 deletions openmeter/billing/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ func (c SupplierContact) Validate() error {
return nil
}

type ProfileID models.NamespacedID

func (p ProfileID) Validate() error {
return models.NamespacedID(p).Validate()
}

type BaseProfile struct {
ID string `json:"id"`
Namespace string `json:"namespace"`
Expand Down Expand Up @@ -226,6 +232,13 @@ func (p BaseProfile) Validate() error {
return nil
}

func (p BaseProfile) ProfileID() ProfileID {
return ProfileID{
Namespace: p.Namespace,
ID: p.ID,
}
}

type Profile struct {
BaseProfile

Expand Down Expand Up @@ -403,25 +416,8 @@ func (i GetDefaultProfileInput) Validate() error {
return nil
}

type genericNamespaceID struct {
Namespace string
ID string
}

func (i genericNamespaceID) Validate() error {
if i.Namespace == "" {
return errors.New("namespace is required")
}

if i.ID == "" {
return errors.New("id is required")
}

return nil
}

type GetProfileInput struct {
Profile models.NamespacedID
Profile ProfileID
Expand ProfileExpand
}

Expand All @@ -437,11 +433,7 @@ func (i GetProfileInput) Validate() error {
return nil
}

type DeleteProfileInput genericNamespaceID

func (i DeleteProfileInput) Validate() error {
return genericNamespaceID(i).Validate()
}
type DeleteProfileInput = ProfileID

type UpdateProfileInput BaseProfile

Expand All @@ -457,6 +449,10 @@ func (i UpdateProfileInput) Validate() error {
return BaseProfile(i).Validate()
}

func (i UpdateProfileInput) ProfileID() ProfileID {
return BaseProfile(i).ProfileID()
}

type UpdateProfileAdapterInput struct {
TargetState BaseProfile
WorkflowConfigID string
Expand All @@ -471,13 +467,11 @@ func (i UpdateProfileAdapterInput) Validate() error {
return fmt.Errorf("id is required")
}

if i.TargetState.UpdatedAt.IsZero() {
return fmt.Errorf("updated at is required")
}

if i.WorkflowConfigID == "" {
return fmt.Errorf("workflow config id is required")
}

return nil
}

type UnsetDefaultProfileInput = ProfileID
82 changes: 35 additions & 47 deletions openmeter/billing/service/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,24 @@ func (s *Service) CreateProfile(ctx context.Context, input billing.CreateProfile
}

return transaction.Run(ctx, s.adapter, func(ctx context.Context) (*billing.Profile, error) {
// If a profile is already set as default, we need to unset it
// Given that we have multiple constraints let's validate those here for better error reporting
if input.Default {
defaultProfile, err := s.adapter.GetDefaultProfile(ctx, billing.GetDefaultProfileInput{
oldDefaultProfile, err := s.adapter.GetDefaultProfile(ctx, billing.GetDefaultProfileInput{
Namespace: input.Namespace,
})
if err != nil {
return nil, fmt.Errorf("error fetching default profile: %w", err)
}

if defaultProfile != nil {
if err := s.unsetDefaultProfile(ctx, *defaultProfile); err != nil {
return nil, fmt.Errorf("error unsetting default profile: %w", err)
if oldDefaultProfile != nil {
oldDefaultProfile.Default = false

_, err := s.adapter.UpdateProfile(ctx, billing.UpdateProfileAdapterInput{
TargetState: *oldDefaultProfile,
WorkflowConfigID: oldDefaultProfile.WorkflowConfig.ID,
})
if err != nil {
return nil, err
}
}
}
Expand All @@ -52,6 +58,13 @@ func (s *Service) CreateProfile(ctx context.Context, input billing.CreateProfile
Payment: resolvedApps.Payment.Reference,
}

if resolvedApps.Tax.App.GetType() != resolvedApps.Invoicing.App.GetType() ||
resolvedApps.Tax.App.GetType() != resolvedApps.Payment.App.GetType() {
return nil, billing.ValidationError{
Err: fmt.Errorf("all apps must be of the same type"),
}
}

profile, err := s.adapter.CreateProfile(ctx, input)
if err != nil {
return nil, err
Expand Down Expand Up @@ -209,10 +222,7 @@ func (s *Service) DeleteProfile(ctx context.Context, input billing.DeleteProfile

return transaction.RunWithNoValue(ctx, s.adapter, func(ctx context.Context) error {
profile, err := s.adapter.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: input.Namespace,
ID: input.ID,
},
Profile: input,
})
if err != nil {
return err
Expand All @@ -238,7 +248,7 @@ func (s *Service) DeleteProfile(ctx context.Context, input billing.DeleteProfile
}
}

referringCustomerIDs, err := s.adapter.GetCustomerOverrideReferencingProfile(ctx, billing.HasCustomerOverrideReferencingProfileAdapterInput(input))
referringCustomerIDs, err := s.adapter.GetCustomerOverrideReferencingProfile(ctx, input)
if err != nil {
return err
}
Expand Down Expand Up @@ -307,10 +317,7 @@ func (s *Service) UpdateProfile(ctx context.Context, input billing.UpdateProfile

return transaction.Run(ctx, s.adapter, func(ctx context.Context) (*billing.Profile, error) {
profile, err := s.adapter.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: input.Namespace,
ID: input.ID,
},
Profile: input.ProfileID(),
})
if err != nil {
return nil, err
Expand All @@ -328,27 +335,24 @@ func (s *Service) UpdateProfile(ctx context.Context, input billing.UpdateProfile
}
}

// Get the default profile for the namespace if any
defaultProfile, err := s.adapter.GetDefaultProfile(ctx, billing.GetDefaultProfileInput{
Namespace: input.Namespace,
})
if err != nil {
return nil, fmt.Errorf("error fetching default profile: %w", err)
}
if !profile.Default && input.Default {
oldDefaultProfile, err := s.adapter.GetDefaultProfile(ctx, billing.GetDefaultProfileInput{
Namespace: input.Namespace,
})
if err != nil {
return nil, err
}

if defaultProfile != nil {
// If a different profile is being set as default we need to unset the current default
if input.Default && defaultProfile.ID != input.ID {
if err := s.unsetDefaultProfile(ctx, *defaultProfile); err != nil {
return nil, fmt.Errorf("error unsetting default profile: %w", err)
if oldDefaultProfile != nil {
if err := s.adapter.UnsetDefaultProfile(ctx, oldDefaultProfile.ProfileID()); err != nil {
return nil, err
}
}
}

// If the current profile is the default one it cannot be unset
if !input.Default && defaultProfile.ID == input.ID {
return nil, billing.ValidationError{
Err: fmt.Errorf("%w [id=%s]", billing.ErrDefaultProfileCannotBeUnset, input.ID),
}
if profile.Default && !input.Default {
return nil, billing.ValidationError{
Err: fmt.Errorf("%w [id=%s]", billing.ErrDefaultProfileCannotBeUnset, input.ID),
}
}

Expand Down Expand Up @@ -446,19 +450,3 @@ func (s *Service) resolveProfileApps(ctx context.Context, input *billing.BasePro

return &out, nil
}

// unsetDefaultProfile unsets the default profile for the given namespace
func (s *Service) unsetDefaultProfile(ctx context.Context, defaultProfile billing.BaseProfile) error {
profile := defaultProfile
profile.Default = false

_, err := s.adapter.UpdateProfile(ctx, billing.UpdateProfileAdapterInput{
TargetState: profile,
WorkflowConfigID: profile.WorkflowConfig.ID,
})
if err != nil {
return fmt.Errorf("error unsetting default profile: %w", err)
}

return nil
}
44 changes: 23 additions & 21 deletions test/billing/profile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ func (s *ProfileTestSuite) TestProfileLifecycle() {
profile := s.createProfileFixture(false)

fetchedProfile, err := s.BillingService.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: profile.Namespace,
ID: profile.ID,
},
Profile: profile.ProfileID(),
Expand: billing.ProfileExpand{
Apps: true,
},
Expand Down Expand Up @@ -169,15 +166,29 @@ func (s *ProfileTestSuite) TestProfileLifecycle() {

// Fetch the profile
fetchedProfile, err := s.BillingService.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: profile.Namespace,
ID: profile.ID,
},
Profile: profile.ProfileID(),
})

require.NoError(t, err)
require.Equal(t, profile.ID, fetchedProfile.ID)
})

t.Run("updating a deleted profile yields an error", func(t *testing.T) {
profile := s.createProfileFixture(false)

// Delete the profile
require.NoError(t, s.BillingService.DeleteProfile(ctx, billing.DeleteProfileInput{
Namespace: profile.Namespace,
ID: profile.ID,
}))

// Update the profile
profile.BaseProfile.AppReferences = nil
_, err := s.BillingService.UpdateProfile(ctx, billing.UpdateProfileInput(profile.BaseProfile))

require.ErrorAs(t, err, &billing.ValidationIssue{})
require.ErrorIs(t, err, billing.ErrProfileAlreadyDeleted)
})
})

s.T().Run("update profile handling", func(t *testing.T) {
Expand Down Expand Up @@ -273,11 +284,8 @@ func (s *ProfileTestSuite) TestProfileFieldSetting() {

// Let's fetch the profile again
fetchedProfile, err := s.BillingService.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: ns,
ID: profile.ID,
},
Expand: billing.ProfileExpandAll,
Profile: profile.ProfileID(),
Expand: billing.ProfileExpandAll,
})

// Sanity check db conversion & fetching
Expand Down Expand Up @@ -374,11 +382,8 @@ func (s *ProfileTestSuite) TestProfileUpdates() {

// Let's fetch the profile again
fetchedProfile, err := s.BillingService.GetProfile(ctx, billing.GetProfileInput{
Profile: models.NamespacedID{
Namespace: ns,
ID: profile.ID,
},
Expand: billing.ProfileExpandAll,
Profile: profile.ProfileID(),
Expand: billing.ProfileExpandAll,
})

// Sanity check db conversion & fetching
Expand All @@ -393,9 +398,6 @@ func (s *ProfileTestSuite) TestProfileUpdates() {
Default: true,
Name: "Awesome Default Profile [update]",
Description: lo.ToPtr("Updated description"),
CreatedAt: profile.CreatedAt,

UpdatedAt: profile.UpdatedAt,

WorkflowConfig: billing.WorkflowConfig{
CreatedAt: profile.WorkflowConfig.CreatedAt,
Expand Down

0 comments on commit 9313845

Please sign in to comment.