Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ACM-13056] Updated discovery controller to support service-account authentication #395

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions controllers/discoveredcluster_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,14 @@ func (r *DiscoveredClusterReconciler) EnsureAutoImportSecret(ctx context.Context
return ctrl.Result{RequeueAfter: recon.WarningRefreshInterval}, err
}

if apiToken, err := parseUserToken(&existingSecret); err == nil {
if authRequest, err := parseSecretForAuth(&existingSecret); err == nil {
nn := types.NamespacedName{Name: "auto-import-secret", Namespace: dc.Spec.DisplayName}
existingSecret = corev1.Secret{}

if err := r.Get(ctx, nn, &existingSecret, &client.GetOptions{}); apierrors.IsNotFound(err) {
logf.Info("Creating auto-import-secret for managed cluster", "Namespace", nn.Namespace)

s := r.CreateAutoImportSecret(nn, dc.Spec.RHOCMClusterID, apiToken)
s := r.CreateAutoImportSecret(nn, dc.Spec.RHOCMClusterID, authRequest.Token)
if err := r.Create(ctx, s); err != nil {
logf.Error(err, "failed to create auto-import Secret for ManagedCluster", "Name", nn.Name)
return ctrl.Result{RequeueAfter: recon.ErrorRefreshInterval}, err
Expand Down
63 changes: 53 additions & 10 deletions controllers/discoveryconfig_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (

discovery "github.com/stolostron/discovery/api/v1"
"github.com/stolostron/discovery/pkg/ocm"
"github.com/stolostron/discovery/pkg/ocm/auth"
recon "github.com/stolostron/discovery/util/reconciler"
corev1 "k8s.io/api/core/v1"
)
Expand Down Expand Up @@ -131,21 +132,23 @@ func (r *DiscoveryConfigReconciler) updateDiscoveredClusters(ctx context.Context
}

// Parse user token from ocm secret.
userToken, err := parseUserToken(ocmSecret)
authRequest, err := parseSecretForAuth(ocmSecret)

if err != nil {
logf.Error(err, "Error parsing token from secret. Deleting all clusters.", "Secret", ocmSecret.GetName())
return r.deleteAllClusters(ctx, config)
}

baseURL := getURLOverride(config)
baseAuthURL := getAuthURLOverride(config)
// Set the baseURL for authentication requests.
authRequest.BaseURL = getURLOverride(config)
authRequest.BaseAuthURL = getAuthURLOverride(config)
filters := config.Spec.Filters

discovered, err := []discovery.DiscoveredCluster{}, nil
if val, ok := os.LookupEnv("UNIT_TEST"); ok && val == "true" {
discovered, err = mockDiscoveredCluster()
} else {
discovered, err = ocm.DiscoverClusters(userToken, baseURL, baseAuthURL, filters)
discovered, err = ocm.DiscoverClusters(authRequest, filters)
}

if err != nil {
Expand Down Expand Up @@ -218,14 +221,54 @@ func (r *DiscoveryConfigReconciler) validateDiscoveryConfigName(reqName string)
return nil
}

// parseUserToken takes a secret cotaining credentials and returns the stored OCM api token.
func parseUserToken(secret *corev1.Secret) (string, error) {
token, ok := secret.Data["ocmAPIToken"]
if !ok {
return "", fmt.Errorf("%s: %w", secret.Name, ErrBadFormat)
/*
parseSecretForAuth parses the given Secret to retrieve authentication credentials.
Depending on the "auth_method" field in the secret, it returns either service account credentials
(client_id, client_secret) or an offline token (ocmAPIToken). If "auth_method" is not set, it
defaults to using the "offline-token" method. Returns an error if the expected fields are missing.
*/
func parseSecretForAuth(secret *corev1.Secret) (auth.AuthRequest, error) {
// Set the default auth_method to "offline-token"
authMethod := "offline-token"

// Check if the "auth_method" key is present in the Secret data
if method, found := secret.Data["auth_method"]; found {
authMethod = string(method)
}

credentials := auth.AuthRequest{
AuthMethod: strings.TrimSuffix(string(authMethod), "\n"), // Set the authentication method
}

// Handle based on the "auth_method" value
switch credentials.AuthMethod {
case "service-account":
// Retrieve client_id and client_secret for service-account auth method
clientID, idOk := secret.Data["client_id"]
clientSecret, secretOk := secret.Data["client_secret"]

if !idOk || !secretOk {
return credentials, fmt.Errorf(
"%s: bad format: secret must contain client_id and client_secret", secret.Name)
}

credentials.ID = strings.TrimSuffix(string(clientID), "\n")
credentials.Secret = strings.TrimSuffix(string(clientSecret), "\n")

case "offline-token":
// Retrive ocmAPIToken for offline-token auth method
token, tokenOk := secret.Data["ocmAPIToken"]
if !tokenOk {
return credentials, fmt.Errorf("%s: bad format: secret must contain ocmAPIToken", secret.Name)
}

credentials.Token = strings.TrimSuffix(string(token), "\n")

default:
return credentials, fmt.Errorf("%s: bad format: unsupported auth_method: %s", secret.Name, authMethod)
}

return strings.TrimSuffix(string(token), "\n"), nil
return credentials, nil
}

// assignManagedStatus marks clusters in the discovered map as managed if they are in the managed list
Expand Down
74 changes: 67 additions & 7 deletions controllers/discoveryconfig_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ var _ = Describe("Discoveryconfig controller", func() {

})

func Test_parseUserToken(t *testing.T) {
func Test_parseSecretForAuth(t *testing.T) {
tests := []struct {
name string
secret *corev1.Secret
want string
want auth.AuthRequest
wantErr bool
}{
{
Expand All @@ -233,10 +233,14 @@ func Test_parseUserToken(t *testing.T) {
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("offline-token"),
"ocmAPIToken": []byte("dummytoken"),
},
},
want: "dummytoken",
want: auth.AuthRequest{
AuthMethod: "offline-token",
Token: "dummytoken",
},
wantErr: false,
},
{
Expand All @@ -247,19 +251,75 @@ func Test_parseUserToken(t *testing.T) {
Namespace: "test",
},
},
want: "",
want: auth.AuthRequest{
AuthMethod: "offline-token",
Token: "",
},
wantErr: true,
},
{
name: "Dummy service account token",
secret: &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("service-account"),
"client_id": []byte("dc05925d-630b-408b-bfb7-02099be7b789"),
"client_secret": []byte("ZZocNUZWgYSuJHIqK0j0D1mZVdufng6z"),
},
},
want: auth.AuthRequest{
AuthMethod: "service-account",
ID: "dc05925d-630b-408b-bfb7-02099be7b789",
Secret: "ZZocNUZWgYSuJHIqK0j0D1mZVdufng6z",
},
wantErr: false,
},
{
name: "Missing field service account",
secret: &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("service-account"),
"client_id": []byte("dc05925d-630b-408b-bfb7-02099be7b789"),
},
},
want: auth.AuthRequest{
AuthMethod: "service-account",
},
wantErr: true,
},
{
name: "Invalid authentication method",
secret: &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "test",
},
Data: map[string][]byte{
"auth_method": []byte("invalid-method"),
},
},
want: auth.AuthRequest{
AuthMethod: "invalid-method",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseUserToken(tt.secret)
got, err := parseSecretForAuth(tt.secret)
if (err != nil) != tt.wantErr {
t.Errorf("parseUserToken() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("parseSecretForAuth() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("parseUserToken() = %v, want %v", got, tt.want)
t.Errorf("parseSecretForAuth() = %v, want %v", got, tt.want)
}
})
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/ocm/auth/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ type AuthTokenResponse struct {
}

type AuthRequest struct {
BaseURL string
Token string
AuthMethod string `json:"auth_method,omitempty" yaml:"auth_method,omitempty"`
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
BaseAuthURL string `json:"base_auth_url,omitempty" yaml:"base_auth_url,omitempty"`
ID string `json:"client_id,omitempty" yaml:"client_id,omitempty"`
Secret string `json:"client_secret,omitempty" yaml:"client_secret,omitempty"`
Token string `json:"ocmAPIToken,omitempty" yaml:"ocmAPIToken,omitempty"`
}

type AuthError struct {
Expand Down
20 changes: 16 additions & 4 deletions pkg/ocm/auth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,22 @@ type authProvider struct{}

func (a *authProvider) GetToken(request AuthRequest) (retRes *AuthTokenResponse, retErr *AuthError) {
postUrl := fmt.Sprintf(authEndpoint, request.BaseURL)
data := url.Values{
"grant_type": {"refresh_token"},
"client_id": {"cloud-services"},
"refresh_token": {request.Token},

var data url.Values
switch request.AuthMethod {
case "service-account":
data = url.Values{
"grant_type": {"client_credentials"},
"client_id": {request.ID},
"client_secret": {request.Secret},
}

default:
data = url.Values{
"grant_type": {"refresh_token"},
"client_id": {"cloud-services"},
"refresh_token": {request.Token},
}
}

response, err := httpClient.Post(postUrl, data)
Expand Down
4 changes: 2 additions & 2 deletions pkg/ocm/auth/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestProviderGetTokenNoError(t *testing.T) {
assert.EqualValues(t, "new_access_token", response)
}

// recieved an AuthTokenResponse but it's missing and `access_token`
// received an AuthTokenResponse but it's missing and `access_token`
func TestGetTokenMissingAccessToken(t *testing.T) {
getTokenFunc = func(request AuthRequest) (*AuthTokenResponse, *AuthError) {
return &AuthTokenResponse{}, nil
Expand All @@ -50,7 +50,7 @@ func TestGetTokenMissingAccessToken(t *testing.T) {
assert.EqualValues(t, "", response)
}

// recieved an error caused by unmarshalling rather than from the API
// received an error caused by unmarshalling rather than from the API
func TestGetTokenInvalidErrorResponse(t *testing.T) {
getTokenFunc = func(request AuthRequest) (*AuthTokenResponse, *AuthError) {
return nil, &AuthError{
Expand Down
9 changes: 3 additions & 6 deletions pkg/ocm/ocm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ import (

// DiscoverClusters returns a list of DiscoveredClusters found in both the accounts_mgmt and
// accounts_mgmt apis with the given filters
func DiscoverClusters(token string, baseURL string, baseAuthURL string, filters discovery.Filter) ([]discovery.DiscoveredCluster, error) {
func DiscoverClusters(authRequest auth.AuthRequest, filters discovery.Filter) ([]discovery.DiscoveredCluster, error) {
// Request ephemeral access token with user token. This will be used for OCM requests
authRequest := auth.AuthRequest{
Token: token,
BaseURL: baseAuthURL,
}
accessToken, err := auth.AuthClient.GetToken(authRequest)
if err != nil {
return nil, err
Expand All @@ -29,9 +25,10 @@ func DiscoverClusters(token string, baseURL string, baseAuthURL string, filters
// Get subscriptions from accounts_mgmt api
subscriptionRequestConfig := subscription.SubscriptionRequest{
Token: accessToken,
BaseURL: baseURL,
BaseURL: authRequest.BaseURL,
Filter: filters,
}

subscriptionClient := subscription.SubscriptionClientGenerator.NewClient(subscriptionRequestConfig)
subscriptions, err := subscriptionClient.GetSubscriptions()
if err != nil {
Expand Down
28 changes: 10 additions & 18 deletions pkg/ocm/ocm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,11 @@ func subscriptionResponse(testdata string) func() ([]subscription.Subscription,
}

func TestDiscoverClusters(t *testing.T) {
type args struct {
token string
baseURL string
baseAuthURL string
filters discovery.Filter
}
tests := []struct {
name string
authfunc func(auth.AuthRequest) (string, error)
subscriptionFunc func() ([]subscription.Subscription, error)
args args
authRequest auth.AuthRequest
want int
wantErr bool
}{
Expand All @@ -76,11 +70,10 @@ func TestDiscoverClusters(t *testing.T) {
},
// this mock returns 3 subscriptions read from mock_subscriptions.json
subscriptionFunc: subscriptionResponse("testdata/1_mock_subscription.json"),
args: args{
token: "test",
baseURL: "test",
baseAuthURL: "test",
filters: discovery.Filter{},
authRequest: auth.AuthRequest{
Token: "test",
BaseURL: "test",
BaseAuthURL: "test",
},
want: 1,
wantErr: false,
Expand All @@ -93,11 +86,10 @@ func TestDiscoverClusters(t *testing.T) {
},
// this mock returns 3 subscriptions read from mock_subscriptions.json
subscriptionFunc: subscriptionResponse("testdata/3_mock_subscriptions.json"),
args: args{
token: "test",
baseURL: "test",
baseAuthURL: "test",
filters: discovery.Filter{},
authRequest: auth.AuthRequest{
Token: "test",
BaseURL: "test",
BaseAuthURL: "test",
},
want: 2,
wantErr: false,
Expand All @@ -112,7 +104,7 @@ func TestDiscoverClusters(t *testing.T) {
// TODO: Running `getSubscriptionsFunc` should yield the subscriptions to test against, but we don't do this
getSubscriptionsFunc = tt.subscriptionFunc

got, err := DiscoverClusters(tt.args.token, tt.args.baseURL, tt.args.baseAuthURL, tt.args.filters)
got, err := DiscoverClusters(tt.authRequest, discovery.Filter{})
if (err != nil) != tt.wantErr {
t.Errorf("DiscoverClusters() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
11 changes: 6 additions & 5 deletions pkg/ocm/subscription/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ type SubscriptionRequest struct {
// SubscriptionError represents the error format response by OCM on a subscription request.
// Full list of responses available at https://api.openshift.com/api/accounts_mgmt/v1/errors/
type SubscriptionError struct {
Kind string `json:"kind"`
ID string `json:"id"`
Href string `json:"href"`
Code string `json:"code"`
Reason string `json:"reason"`
Kind string `json:"kind"`
ID string `json:"id"`
Href string `json:"href"`
Code string `json:"code"`
Reason string `json:"reason"`
StatusCode int `json:"status_code,omitempty"`
// Error is for setting an internal error for tracking
Error error `json:"-"`
// Response is for storing the raw response on an error
Expand Down
2 changes: 2 additions & 0 deletions pkg/ocm/subscription/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ func parseResponse(response *http.Response) (*SubscriptionResponse, *Subscriptio
errResponse.Error = fmt.Errorf("unexpected json response body")
errResponse.Response = bytes
}

errResponse.StatusCode = response.StatusCode
return nil, &errResponse
}

Expand Down
Loading
Loading