Skip to content

Commit

Permalink
[ACM-13056] Updated discovery controller to support service-account a…
Browse files Browse the repository at this point in the history
…uthentication (#395)

* updated discovery controller to support service-account authentication

Signed-off-by: dislbenn <[email protected]>

* reverted return statements for ocm funcs

Signed-off-by: dislbenn <[email protected]>

* updated test case

Signed-off-by: dislbenn <[email protected]>

* updated test case

Signed-off-by: dislbenn <[email protected]>

* added status code check for http request response

Signed-off-by: dislbenn <[email protected]>

* updated comments

Signed-off-by: dislbenn <[email protected]>

* updated service.go

Signed-off-by: dislbenn <[email protected]>

* updated discoveryconfig logic for secret auth

Signed-off-by: dislbenn <[email protected]>

---------

Signed-off-by: dislbenn <[email protected]>
  • Loading branch information
dislbenn authored Aug 26, 2024
1 parent f1c037e commit 35e3159
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 61 deletions.
4 changes: 2 additions & 2 deletions controllers/discoveredcluster_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,14 @@ func (r *DiscoveredClusterReconciler) EnsureAutoImportSecret(ctx context.Context
return ctrl.Result{RequeueAfter: recon.WarningRefreshInterval}, err
}

if apiToken, err := parseUserToken(&existingSecret); err == nil {
if authRequest, err := parseSecretForAuth(&existingSecret); err == nil {
nn := types.NamespacedName{Name: "auto-import-secret", Namespace: dc.Spec.DisplayName}
existingSecret = corev1.Secret{}

if err := r.Get(ctx, nn, &existingSecret, &client.GetOptions{}); apierrors.IsNotFound(err) {
logf.Info("Creating auto-import-secret for managed cluster", "Namespace", nn.Namespace)

s := r.CreateAutoImportSecret(nn, dc.Spec.RHOCMClusterID, apiToken)
s := r.CreateAutoImportSecret(nn, dc.Spec.RHOCMClusterID, authRequest.Token)
if err := r.Create(ctx, s); err != nil {
logf.Error(err, "failed to create auto-import Secret for ManagedCluster", "Name", nn.Name)
return ctrl.Result{RequeueAfter: recon.ErrorRefreshInterval}, err
Expand Down
90 changes: 80 additions & 10 deletions controllers/discoveryconfig_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (

discovery "github.com/stolostron/discovery/api/v1"
"github.com/stolostron/discovery/pkg/ocm"
"github.com/stolostron/discovery/pkg/ocm/auth"
recon "github.com/stolostron/discovery/util/reconciler"
corev1 "k8s.io/api/core/v1"
)
Expand Down Expand Up @@ -130,22 +131,31 @@ func (r *DiscoveryConfigReconciler) updateDiscoveredClusters(ctx context.Context
return err
}

// Update secret to include default authentication method if the field is missing.
if _, found := ocmSecret.Data["auth_method"]; !found {
if err := r.AddDefaultAuthMethodToSecret(ctx, ocmSecret); err != nil {
return err
}
}

// Parse user token from ocm secret.
userToken, err := parseUserToken(ocmSecret)
authRequest, err := parseSecretForAuth(ocmSecret)

if err != nil {
logf.Error(err, "Error parsing token from secret. Deleting all clusters.", "Secret", ocmSecret.GetName())
return r.deleteAllClusters(ctx, config)
}

baseURL := getURLOverride(config)
baseAuthURL := getAuthURLOverride(config)
// Set the baseURL for authentication requests.
authRequest.BaseURL = getURLOverride(config)
authRequest.BaseAuthURL = getAuthURLOverride(config)
filters := config.Spec.Filters

discovered, err := []discovery.DiscoveredCluster{}, nil
if val, ok := os.LookupEnv("UNIT_TEST"); ok && val == "true" {
discovered, err = mockDiscoveredCluster()
} else {
discovered, err = ocm.DiscoverClusters(userToken, baseURL, baseAuthURL, filters)
discovered, err = ocm.DiscoverClusters(authRequest, filters)
}

if err != nil {
Expand Down Expand Up @@ -218,14 +228,74 @@ func (r *DiscoveryConfigReconciler) validateDiscoveryConfigName(reqName string)
return nil
}

// parseUserToken takes a secret cotaining credentials and returns the stored OCM api token.
func parseUserToken(secret *corev1.Secret) (string, error) {
token, ok := secret.Data["ocmAPIToken"]
if !ok {
return "", fmt.Errorf("%s: %w", secret.Name, ErrBadFormat)
/*
parseSecretForAuth parses the given Secret to retrieve authentication credentials.
Depending on the "auth_method" field in the secret, it returns either service account credentials
(client_id, client_secret) or an offline token (ocmAPIToken). If "auth_method" is not set, it
defaults to using the "offline-token" method. Returns an error if the expected fields are missing.
*/
func parseSecretForAuth(secret *corev1.Secret) (auth.AuthRequest, error) {
// Set the default auth_method to "offline-token"
authMethod := "offline-token"

// Check if the "auth_method" key is present in the Secret data
if method, found := secret.Data["auth_method"]; found {
authMethod = string(method)
}

return strings.TrimSuffix(string(token), "\n"), nil
credentials := auth.AuthRequest{
AuthMethod: strings.TrimSuffix(string(authMethod), "\n"), // Set the authentication method
}

// Handle based on the "auth_method" value
switch credentials.AuthMethod {
case "service-account":
// Retrieve client_id and client_secret for service-account auth method
clientID, idOk := secret.Data["client_id"]
clientSecret, secretOk := secret.Data["client_secret"]

if !idOk || !secretOk {
return credentials, fmt.Errorf(
"%s: bad format: secret must contain client_id and client_secret", secret.Name)
}

credentials.ID = strings.TrimSuffix(string(clientID), "\n")
credentials.Secret = strings.TrimSuffix(string(clientSecret), "\n")

case "offline-token":
// Retrive ocmAPIToken for offline-token auth method
token, tokenOk := secret.Data["ocmAPIToken"]
if !tokenOk {
return credentials, fmt.Errorf("%s: bad format: secret must contain ocmAPIToken", secret.Name)
}

credentials.Token = strings.TrimSuffix(string(token), "\n")

default:
return credentials, fmt.Errorf("%s: bad format: unsupported auth_method: %s", secret.Name, authMethod)
}

return credentials, nil
}

func (r *DiscoveryConfigReconciler) AddDefaultAuthMethodToSecret(ctx context.Context, secret *corev1.Secret) error {
// Set the default auth_method to "offline-token"
secret.Data["auth_method"] = []byte("offline-token")

// Check if both client_id and client_secret are present in the secret data
if _, idOK := secret.Data["client_id"]; idOK {
if _, secretOk := secret.Data["client_secret"]; secretOk {
secret.Data["auth_method"] = []byte("service-account")
}
}

// Update the secret
if err := r.Client.Update(ctx, secret); err != nil {
logf.Error(err, "failed to update Secret with default auth_method: 'offline-token'", "Name", secret.GetName())
return err
}

return nil
}

// assignManagedStatus marks clusters in the discovered map as managed if they are in the managed list
Expand Down
74 changes: 67 additions & 7 deletions controllers/discoveryconfig_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ var _ = Describe("Discoveryconfig controller", func() {

})

func Test_parseUserToken(t *testing.T) {
func Test_parseSecretForAuth(t *testing.T) {
tests := []struct {
name string
secret *corev1.Secret
want string
want auth.AuthRequest
wantErr bool
}{
{
Expand All @@ -233,10 +233,14 @@ func Test_parseUserToken(t *testing.T) {
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("offline-token"),
"ocmAPIToken": []byte("dummytoken"),
},
},
want: "dummytoken",
want: auth.AuthRequest{
AuthMethod: "offline-token",
Token: "dummytoken",
},
wantErr: false,
},
{
Expand All @@ -247,19 +251,75 @@ func Test_parseUserToken(t *testing.T) {
Namespace: "test",
},
},
want: "",
want: auth.AuthRequest{
AuthMethod: "offline-token",
Token: "",
},
wantErr: true,
},
{
name: "Dummy service account token",
secret: &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("service-account"),
"client_id": []byte("dc05925d-630b-408b-bfb7-02099be7b789"),
"client_secret": []byte("ZZocNUZWgYSuJHIqK0j0D1mZVdufng6z"),
},
},
want: auth.AuthRequest{
AuthMethod: "service-account",
ID: "dc05925d-630b-408b-bfb7-02099be7b789",
Secret: "ZZocNUZWgYSuJHIqK0j0D1mZVdufng6z",
},
wantErr: false,
},
{
name: "Missing field service account",
secret: &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("service-account"),
"client_id": []byte("dc05925d-630b-408b-bfb7-02099be7b789"),
},
},
want: auth.AuthRequest{
AuthMethod: "service-account",
},
wantErr: true,
},
{
name: "Invalid authentication method",
secret: &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("invalid-method"),
},
},
want: auth.AuthRequest{
AuthMethod: "invalid-method",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseUserToken(tt.secret)
got, err := parseSecretForAuth(tt.secret)
if (err != nil) != tt.wantErr {
t.Errorf("parseUserToken() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("parseSecretForAuth() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("parseUserToken() = %v, want %v", got, tt.want)
t.Errorf("parseSecretForAuth() = %v, want %v", got, tt.want)
}
})
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/ocm/auth/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ type AuthTokenResponse struct {
}

type AuthRequest struct {
BaseURL string
Token string
AuthMethod string `json:"auth_method,omitempty" yaml:"auth_method,omitempty"`
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
BaseAuthURL string `json:"base_auth_url,omitempty" yaml:"base_auth_url,omitempty"`
ID string `json:"client_id,omitempty" yaml:"client_id,omitempty"`
Secret string `json:"client_secret,omitempty" yaml:"client_secret,omitempty"`
Token string `json:"ocmAPIToken,omitempty" yaml:"ocmAPIToken,omitempty"`
}

type AuthError struct {
Expand Down
20 changes: 16 additions & 4 deletions pkg/ocm/auth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,22 @@ type authProvider struct{}

func (a *authProvider) GetToken(request AuthRequest) (retRes *AuthTokenResponse, retErr *AuthError) {
postUrl := fmt.Sprintf(authEndpoint, request.BaseURL)
data := url.Values{
"grant_type": {"refresh_token"},
"client_id": {"cloud-services"},
"refresh_token": {request.Token},

var data url.Values
switch request.AuthMethod {
case "service-account":
data = url.Values{
"grant_type": {"client_credentials"},
"client_id": {request.ID},
"client_secret": {request.Secret},
}

default:
data = url.Values{
"grant_type": {"refresh_token"},
"client_id": {"cloud-services"},
"refresh_token": {request.Token},
}
}

response, err := httpClient.Post(postUrl, data)
Expand Down
4 changes: 2 additions & 2 deletions pkg/ocm/auth/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestProviderGetTokenNoError(t *testing.T) {
assert.EqualValues(t, "new_access_token", response)
}

// recieved an AuthTokenResponse but it's missing and `access_token`
// received an AuthTokenResponse but it's missing and `access_token`
func TestGetTokenMissingAccessToken(t *testing.T) {
getTokenFunc = func(request AuthRequest) (*AuthTokenResponse, *AuthError) {
return &AuthTokenResponse{}, nil
Expand All @@ -50,7 +50,7 @@ func TestGetTokenMissingAccessToken(t *testing.T) {
assert.EqualValues(t, "", response)
}

// recieved an error caused by unmarshalling rather than from the API
// received an error caused by unmarshalling rather than from the API
func TestGetTokenInvalidErrorResponse(t *testing.T) {
getTokenFunc = func(request AuthRequest) (*AuthTokenResponse, *AuthError) {
return nil, &AuthError{
Expand Down
9 changes: 3 additions & 6 deletions pkg/ocm/ocm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ import (

// DiscoverClusters returns a list of DiscoveredClusters found in both the accounts_mgmt and
// accounts_mgmt apis with the given filters
func DiscoverClusters(token string, baseURL string, baseAuthURL string, filters discovery.Filter) ([]discovery.DiscoveredCluster, error) {
func DiscoverClusters(authRequest auth.AuthRequest, filters discovery.Filter) ([]discovery.DiscoveredCluster, error) {
// Request ephemeral access token with user token. This will be used for OCM requests
authRequest := auth.AuthRequest{
Token: token,
BaseURL: baseAuthURL,
}
accessToken, err := auth.AuthClient.GetToken(authRequest)
if err != nil {
return nil, err
Expand All @@ -29,9 +25,10 @@ func DiscoverClusters(token string, baseURL string, baseAuthURL string, filters
// Get subscriptions from accounts_mgmt api
subscriptionRequestConfig := subscription.SubscriptionRequest{
Token: accessToken,
BaseURL: baseURL,
BaseURL: authRequest.BaseURL,
Filter: filters,
}

subscriptionClient := subscription.SubscriptionClientGenerator.NewClient(subscriptionRequestConfig)
subscriptions, err := subscriptionClient.GetSubscriptions()
if err != nil {
Expand Down
28 changes: 10 additions & 18 deletions pkg/ocm/ocm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,11 @@ func subscriptionResponse(testdata string) func() ([]subscription.Subscription,
}

func TestDiscoverClusters(t *testing.T) {
type args struct {
token string
baseURL string
baseAuthURL string
filters discovery.Filter
}
tests := []struct {
name string
authfunc func(auth.AuthRequest) (string, error)
subscriptionFunc func() ([]subscription.Subscription, error)
args args
authRequest auth.AuthRequest
want int
wantErr bool
}{
Expand All @@ -76,11 +70,10 @@ func TestDiscoverClusters(t *testing.T) {
},
// this mock returns 3 subscriptions read from mock_subscriptions.json
subscriptionFunc: subscriptionResponse("testdata/1_mock_subscription.json"),
args: args{
token: "test",
baseURL: "test",
baseAuthURL: "test",
filters: discovery.Filter{},
authRequest: auth.AuthRequest{
Token: "test",
BaseURL: "test",
BaseAuthURL: "test",
},
want: 1,
wantErr: false,
Expand All @@ -93,11 +86,10 @@ func TestDiscoverClusters(t *testing.T) {
},
// this mock returns 3 subscriptions read from mock_subscriptions.json
subscriptionFunc: subscriptionResponse("testdata/3_mock_subscriptions.json"),
args: args{
token: "test",
baseURL: "test",
baseAuthURL: "test",
filters: discovery.Filter{},
authRequest: auth.AuthRequest{
Token: "test",
BaseURL: "test",
BaseAuthURL: "test",
},
want: 2,
wantErr: false,
Expand All @@ -112,7 +104,7 @@ func TestDiscoverClusters(t *testing.T) {
// TODO: Running `getSubscriptionsFunc` should yield the subscriptions to test against, but we don't do this
getSubscriptionsFunc = tt.subscriptionFunc

got, err := DiscoverClusters(tt.args.token, tt.args.baseURL, tt.args.baseAuthURL, tt.args.filters)
got, err := DiscoverClusters(tt.authRequest, discovery.Filter{})
if (err != nil) != tt.wantErr {
t.Errorf("DiscoverClusters() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
Loading

0 comments on commit 35e3159

Please sign in to comment.