diff --git a/service/servicemessagehandler/servicemessagehandler.go b/service/servicemessagehandler/servicemessagehandler.go index f4312dd..7532da4 100644 --- a/service/servicemessagehandler/servicemessagehandler.go +++ b/service/servicemessagehandler/servicemessagehandler.go @@ -63,8 +63,8 @@ func (rd *ServiceMessageHandler) sendMessage(ctx context.Context, serviceName st return nil } -func (rd *ServiceMessageHandler) DispatchTraTVerificationRule(ctx context.Context, serviceName string, namespace string, verificationEndpointRule *v1alpha1.TraTVerificationRule, versionNumber int64) error { - jsonData, err := json.Marshal(verificationEndpointRule) +func (rd *ServiceMessageHandler) DispatchTraTVerificationRule(ctx context.Context, serviceName string, namespace string, serviceTraTVerificationRules *v1alpha1.ServiceTraTVerificationRules, versionNumber int64) error { + jsonData, err := json.Marshal(serviceTraTVerificationRules) if err != nil { return fmt.Errorf("error marshaling verification trat rule: %w", err) } diff --git a/service/tratteriacontroller/controller/controller.go b/service/tratteriacontroller/controller/controller.go index 828f184..2f4497c 100644 --- a/service/tratteriacontroller/controller/controller.go +++ b/service/tratteriacontroller/controller/controller.go @@ -371,14 +371,14 @@ func (c *Controller) GetActiveVerificationRules(serviceName string, namespace st return nil, 0, err } - traTVerificationRules, err := c.GetActiveTraTsVerificationRules(serviceName, namespace) + traTsVerificationRules, err := c.GetActiveTraTsVerificationRules(serviceName, namespace) if err != nil { return nil, 0, err } return &tratteria1alpha1.VerificationRules{ TratteriaConfigVerificationRule: tratteriaConfigVerificationRule, - TraTsVerificationRules: traTVerificationRules, + TraTsVerificationRules: traTsVerificationRules, }, activeRuleVersionNumber, nil diff --git a/service/tratteriacontroller/controller/tratcontroller.go b/service/tratteriacontroller/controller/tratcontroller.go index 6f0958c..3a28768 100644 --- a/service/tratteriacontroller/controller/tratcontroller.go +++ b/service/tratteriacontroller/controller/tratcontroller.go @@ -14,7 +14,7 @@ import ( ) func (c *Controller) handleTraTUpsert(ctx context.Context, newTraT *tratteria1alpha1.TraT, versionNumber int64) error { - verificationEndpointRules, err := newTraT.GetTraTVerificationRules() + servicestraTVerificationRules, err := newTraT.GetTraTVerificationRules() if err != nil { messagedErr := fmt.Errorf("error retrieving verification rules from %s trat: %w", newTraT.Name, err) @@ -28,8 +28,8 @@ func (c *Controller) handleTraTUpsert(ctx context.Context, newTraT *tratteria1al } // TODO: Implement parallel dispatching of rules using goroutines - for service, serviceVerificationRule := range verificationEndpointRules { - err := c.serviceMessageHandler.DispatchTraTVerificationRule(ctx, service, newTraT.Namespace, serviceVerificationRule, versionNumber) + for service, serviceTraTVerificationRules := range servicestraTVerificationRules { + err := c.serviceMessageHandler.DispatchTraTVerificationRule(ctx, service, newTraT.Namespace, serviceTraTVerificationRules, versionNumber) if err != nil { messagedErr := fmt.Errorf("error dispatching %s trat verification rule to %s service: %w", newTraT.Name, service, err) @@ -83,16 +83,21 @@ func (c *Controller) handleTraTUpsert(ctx context.Context, newTraT *tratteria1al func (c *Controller) handleTraTUpdation(ctx context.Context, newTraT *tratteria1alpha1.TraT, oldTraT *tratteria1alpha1.TraT, versionNumber int64) error { // First, handle any service removals from the TraT newServices := make(map[string]bool) - for _, serviceSpec := range newTraT.Spec.Services { - newServices[serviceSpec.Name] = true + for _, newServiceSpec := range newTraT.Spec.Services { + newServices[newServiceSpec.Name] = true } + oldServices := make(map[string]bool) for _, oldServiceSpec := range oldTraT.Spec.Services { - if !newServices[oldServiceSpec.Name] { + oldServices[oldServiceSpec.Name] = true + } + + for oldService := range oldServices { + if !newServices[oldService] { // Service was removed, so remove the TraT from it - err := c.serviceMessageHandler.DeleteTraT(ctx, oldServiceSpec.Name, oldTraT.Namespace, oldTraT.Name, versionNumber) + err := c.serviceMessageHandler.DeleteTraT(ctx, oldService, oldTraT.Namespace, oldTraT.Name, versionNumber) if err != nil { - messagedErr := fmt.Errorf("error deleting %s trat from %s service: %w", oldTraT.Name, oldServiceSpec.Name, err) + messagedErr := fmt.Errorf("error deleting %s trat from %s service: %w", oldTraT.Name, oldService, err) c.recorder.Event(newTraT, corev1.EventTypeWarning, "error", messagedErr.Error()) @@ -110,23 +115,24 @@ func (c *Controller) handleTraTUpdation(ctx context.Context, newTraT *tratteria1 } func (c *Controller) handleTraTDeletion(ctx context.Context, oldTraT *tratteria1alpha1.TraT, versionNumber int64) error { - // TODO: Implement parallel requests using goroutines + services := make(map[string]bool) + for _, serviceSpec := range oldTraT.Spec.Services { - err := c.serviceMessageHandler.DeleteTraT(ctx, serviceSpec.Name, oldTraT.Namespace, oldTraT.Name, versionNumber) + services[serviceSpec.Name] = true + } + + services[common.TRATTERIA_SERVICE_NAME] = true + + // TODO: Implement parallel requests using goroutines + for service := range services { + err := c.serviceMessageHandler.DeleteTraT(ctx, service, oldTraT.Namespace, oldTraT.Name, versionNumber) if err != nil { - messagedErr := fmt.Errorf("error deleting %s trat from %s service: %w", oldTraT.Name, serviceSpec.Name, err) + messagedErr := fmt.Errorf("error deleting %s trat from %s service: %w", oldTraT.Name, service, err) return messagedErr } } - err := c.serviceMessageHandler.DeleteTraT(ctx, common.TRATTERIA_SERVICE_NAME, oldTraT.Namespace, oldTraT.Name, versionNumber) - if err != nil { - messagedErr := fmt.Errorf("error deleting %s trat from %s service: %w", oldTraT.Name, common.TRATTERIA_SERVICE_NAME, err) - - return messagedErr - } - return nil } @@ -161,7 +167,7 @@ func (c *Controller) updateSuccessTratStatus(ctx context.Context, trat *tratteri return updateErr } -func (c *Controller) GetActiveTraTsVerificationRules(serviceName string, namespace string) (map[string]*tratteria1alpha1.TraTVerificationRule, error) { +func (c *Controller) GetActiveTraTsVerificationRules(serviceName string, namespace string) (map[string]*tratteria1alpha1.ServiceTraTVerificationRules, error) { traTs, err := c.traTsLister.TraTs(namespace).List(labels.Everything()) if err != nil { c.logger.Error("Failed to list TraTs in namespace.", zap.String("namespace", namespace), zap.Error(err)) @@ -169,19 +175,20 @@ func (c *Controller) GetActiveTraTsVerificationRules(serviceName string, namespa return nil, err } - traTsVerificationRules := make(map[string]*tratteria1alpha1.TraTVerificationRule) + serviceTraTsVerificationRules := make(map[string]*tratteria1alpha1.ServiceTraTVerificationRules) for _, traT := range traTs { - traTVerificationRule, err := traT.GetTraTVerificationRules() + traTVerificationRules, err := traT.GetTraTVerificationRules() if err != nil { return nil, err } - if serviceTraTVerificationRule := traTVerificationRule[serviceName]; serviceTraTVerificationRule != nil { - traTsVerificationRules[traT.Name] = serviceTraTVerificationRule + + if serviceTraTVerificationRules := traTVerificationRules[serviceName]; serviceTraTVerificationRules != nil { + serviceTraTsVerificationRules[traT.Name] = serviceTraTVerificationRules } } - return traTsVerificationRules, nil + return serviceTraTsVerificationRules, nil } func (c *Controller) GetActiveTraTsGenerationRules(namespace string) (map[string]*tratteria1alpha1.TraTGenerationRule, error) { diff --git a/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/types.go b/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/types.go index 331d675..bc54654 100644 --- a/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/types.go +++ b/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/types.go @@ -68,6 +68,11 @@ type TraTVerificationRule struct { AzdMapping AzdMapping `json:"azdmapping,omitempty"` } +type ServiceTraTVerificationRules struct { + TraTName string + TraTVerificationRules []*TraTVerificationRule +} + type TraTGenerationRule struct { TraTName string `json:"traTName"` Endpoint string `json:"endpoint"` @@ -76,10 +81,10 @@ type TraTGenerationRule struct { AzdMapping AzdMapping `json:"azdmapping,omitempty"` } -func (traT *TraT) GetTraTVerificationRules() (map[string]*TraTVerificationRule, error) { - verificationRules := make(map[string]*TraTVerificationRule) - - // TODO: do basic check and return err if failed +// constructs TraT verification for each service present in the call chain +// a single service can have multiple different APIs present in the call chain, so it return the map of list of TraTVerificationRule +func (traT *TraT) GetTraTVerificationRules() (map[string]*ServiceTraTVerificationRules, error) { + servicesTraTVerificationRules := make(map[string]*ServiceTraTVerificationRules) for _, serviceSpec := range traT.Spec.Services { endpoint := traT.Spec.Endpoint @@ -98,25 +103,31 @@ func (traT *TraT) GetTraTVerificationRules() (map[string]*TraTVerificationRule, azdMapping = serviceSpec.AzdMapping } - verificationRules[serviceSpec.Name] = &TraTVerificationRule{ - TraTName: traT.Name, - Endpoint: endpoint, - Method: method, - Purp: traT.Spec.Purp, - AzdMapping: azdMapping, + if servicesTraTVerificationRules[serviceSpec.Name] == nil { + servicesTraTVerificationRules[serviceSpec.Name] = &ServiceTraTVerificationRules{ + TraTName: traT.Name, + } } + servicesTraTVerificationRules[serviceSpec.Name].TraTVerificationRules = append( + servicesTraTVerificationRules[serviceSpec.Name].TraTVerificationRules, + &TraTVerificationRule{ + TraTName: traT.Name, + Endpoint: endpoint, + Method: method, + Purp: traT.Spec.Purp, + AzdMapping: azdMapping, + }) } - if len(verificationRules) == 0 { + if len(servicesTraTVerificationRules) == 0 { return nil, fmt.Errorf("%w: verification rules for %s trat", tconfigderrors.ErrNotFound, traT.Name) } - return verificationRules, nil + return servicesTraTVerificationRules, nil } func (traT *TraT) GetTraTGenerationRule() (*TraTGenerationRule, error) { - // TODO: do basic check and return err if failed return &TraTGenerationRule{ TraTName: traT.Name, @@ -218,8 +229,8 @@ func (tratteriaConfig *TratteriaConfig) GetTratteriaConfigGenerationRule() (*Tra } type VerificationRules struct { - TratteriaConfigVerificationRule *TratteriaConfigVerificationRule `json:"tratteriaConfigVerificationRule"` - TraTsVerificationRules map[string]*TraTVerificationRule `json:"traTsVerificationRules"` + TratteriaConfigVerificationRule *TratteriaConfigVerificationRule `json:"tratteriaConfigVerificationRule"` + TraTsVerificationRules map[string]*ServiceTraTVerificationRules `json:"traTsVerificationRules"` } func (verificationRules *VerificationRules) ComputeStableHash() (string, error) { diff --git a/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/zz_generated.deepcopy.go b/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/zz_generated.deepcopy.go index 1433c15..81e5ee2 100644 --- a/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/zz_generated.deepcopy.go +++ b/service/tratteriacontroller/pkg/apis/tratteria/v1alpha1/zz_generated.deepcopy.go @@ -188,6 +188,33 @@ func (in *ServiceSpec) DeepCopy() *ServiceSpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ServiceTraTVerificationRules) DeepCopyInto(out *ServiceTraTVerificationRules) { + *out = *in + if in.TraTVerificationRules != nil { + in, out := &in.TraTVerificationRules, &out.TraTVerificationRules + *out = make([]*TraTVerificationRule, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(TraTVerificationRule) + (*in).DeepCopyInto(*out) + } + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceTraTVerificationRules. +func (in *ServiceTraTVerificationRules) DeepCopy() *ServiceTraTVerificationRules { + if in == nil { + return nil + } + out := new(ServiceTraTVerificationRules) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *SubjectTokens) DeepCopyInto(out *SubjectTokens) { *out = *in @@ -534,14 +561,14 @@ func (in *VerificationRules) DeepCopyInto(out *VerificationRules) { } if in.TraTsVerificationRules != nil { in, out := &in.TraTsVerificationRules, &out.TraTsVerificationRules - *out = make(map[string]*TraTVerificationRule, len(*in)) + *out = make(map[string]*ServiceTraTVerificationRules, len(*in)) for key, val := range *in { - var outVal *TraTVerificationRule + var outVal *ServiceTraTVerificationRules if val == nil { (*out)[key] = nil } else { in, out := &val, &outVal - *out = new(TraTVerificationRule) + *out = new(ServiceTraTVerificationRules) (*in).DeepCopyInto(*out) } (*out)[key] = outVal