diff --git a/server/pkg/agent/agent_handler.go b/server/pkg/agent/agent_handler.go index 61c0d28ad..ea57935c6 100644 --- a/server/pkg/agent/agent_handler.go +++ b/server/pkg/agent/agent_handler.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/intelops/go-common/logging" + "github.com/kube-tarian/kad/server/pkg/config" "github.com/kube-tarian/kad/server/pkg/credential" oryclient "github.com/kube-tarian/kad/server/pkg/ory-client" "github.com/kube-tarian/kad/server/pkg/store" @@ -14,14 +15,16 @@ import ( type AgentHandler struct { log logging.Logger + cfg config.ServiceConfig agentMutex sync.RWMutex agents map[string]*Agent serverStore store.ServerStore oryClient oryclient.OryClient } -func NewAgentHandler(log logging.Logger, serverStore store.ServerStore, oryClient oryclient.OryClient) *AgentHandler { - return &AgentHandler{log: log, serverStore: serverStore, agents: map[string]*Agent{}, oryClient: oryClient} +func NewAgentHandler(log logging.Logger, cfg config.ServiceConfig, + serverStore store.ServerStore, oryClient oryclient.OryClient) *AgentHandler { + return &AgentHandler{log: log, cfg: cfg, serverStore: serverStore, agents: map[string]*Agent{}, oryClient: oryClient} } func (s *AgentHandler) AddAgent(clusterID string, agentCfg *Config) error { @@ -29,6 +32,8 @@ func (s *AgentHandler) AddAgent(clusterID string, agentCfg *Config) error { return nil } + agentCfg.ServicName = s.cfg.ServiceName + agentCfg.AuthEnabled = s.cfg.AuthEnabled agent, err := NewAgent(s.log, agentCfg, s.oryClient) if err != nil { return err diff --git a/server/pkg/agent/client.go b/server/pkg/agent/client.go index f1fef1daf..f0b095172 100644 --- a/server/pkg/agent/client.go +++ b/server/pkg/agent/client.go @@ -16,6 +16,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" ) type Config struct { @@ -24,6 +25,8 @@ type Config struct { CaCert string Cert string Key string + ServicName string + AuthEnabled bool } type Agent struct { @@ -33,9 +36,9 @@ type Agent struct { log logging.Logger } -func NewAgent(log logging.Logger, cfg *Config, oryclient oryclient.OryClient) (*Agent, error) { +func NewAgent(log logging.Logger, cfg *Config, oryClient oryclient.OryClient) (*Agent, error) { log.Infof("connecting to agent %s", cfg.Address) - conn, err := getConnection(cfg, oryclient) + conn, err := getConnection(cfg, oryClient) if err != nil { return nil, errors.WithMessage(err, "failed to connect to agent") } @@ -57,14 +60,16 @@ func NewAgent(log logging.Logger, cfg *Config, oryclient oryclient.OryClient) (* }, nil } -func getConnection(cfg *Config, client oryclient.OryClient) (*grpc.ClientConn, error) { +func getConnection(cfg *Config, oryClient oryclient.OryClient) (*grpc.ClientConn, error) { address, port, tls, err := parseAgentConnectionConfig(cfg.Address) if err != nil { return nil, err } - dialOptions := []grpc.DialOption{ - grpc.WithUnaryInterceptor(client.UnaryInterceptor), + dialOptions := []grpc.DialOption{} + + if cfg.AuthEnabled { + dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(authInterceptor(oryClient, cfg.ServicName))) } if !tls { @@ -88,6 +93,30 @@ func (a *Agent) Close() { a.connection.Close() } +func authInterceptor(oryClient oryclient.OryClient, serviceName string) grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + oauthCred, err := oryClient.GetServiceOauthCredential(ctx, serviceName) + if err != nil { + return err + } + + md := metadata.Pairs( + "oauth_token", oauthCred.AccessToken, + "ory_url", oauthCred.OryURL, + "ory_pat", oauthCred.OryPAT, + ) + newCtx := metadata.NewOutgoingContext(ctx, md) + return invoker(newCtx, method, req, reply, cc, opts...) + } +} + func loadTLSCredentials(config *Config) (credentials.TransportCredentials, error) { certPool := x509.NewCertPool() if !certPool.AppendCertsFromPEM([]byte(config.CaCert)) { diff --git a/server/pkg/api/cluster_apps.go b/server/pkg/api/cluster_apps.go index d9d54eba3..898089b73 100644 --- a/server/pkg/api/cluster_apps.go +++ b/server/pkg/api/cluster_apps.go @@ -142,7 +142,7 @@ func (s *Server) configureSSOForClusterApps(ctx context.Context, clusterID strin ReleaseName: app.ReleaseName, ClientId: clientID, ClientSecret: clientSecret, - OAuthBaseURL: s.iam.GetURL(), + OAuthBaseURL: s.iam.GetOAuthURL(), }) if err != nil || ssoResp.Status != 0 { diff --git a/server/pkg/api/interceptor.go b/server/pkg/api/interceptor.go index c301bddcb..b45e57e56 100644 --- a/server/pkg/api/interceptor.go +++ b/server/pkg/api/interceptor.go @@ -7,12 +7,7 @@ import ( "google.golang.org/grpc" ) -// UnaryInterceptor is a gRPC server-side interceptor that handles authentication for unary RPCs. -// It first attempts to retrieve an access token from the context using the ORY client interface. -// If the token retrieval is successful, it then tries to authorize the token using the ORY client interface. -// If either step fails, the interceptor logs the error and returns it, halting the RPC. -// If both steps are successful, the interceptor invokes the provided handler with the updated context and request. -func (s Server) AuthInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { +func (s *Server) AuthInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { accessToken, err := s.oryClient.GetSessionTokenFromContext(ctx) if err != nil { s.log.Debugf("error occured while fetching the token from the context. Error - %s", err.Error()) diff --git a/server/pkg/api/server.go b/server/pkg/api/server.go index c1433fc1b..d9aed1879 100644 --- a/server/pkg/api/server.go +++ b/server/pkg/api/server.go @@ -24,15 +24,15 @@ type Server struct { agentHandeler *agent.AgentHandler log logging.Logger oryClient oryclient.OryClient - iam iamclient.SecretManager + iam iamclient.IAMRegister cfg config.ServiceConfig } func NewServer(log logging.Logger, cfg config.ServiceConfig, serverStore store.ServerStore, - oryClient oryclient.OryClient, iam iamclient.SecretManager) (*Server, error) { + oryClient oryclient.OryClient, iam iamclient.IAMRegister) (*Server, error) { return &Server{ serverStore: serverStore, - agentHandeler: agent.NewAgentHandler(log, serverStore, oryClient), + agentHandeler: agent.NewAgentHandler(log, cfg, serverStore, oryClient), log: log, oryClient: oryClient, iam: iam, diff --git a/server/pkg/api/store_apps.go b/server/pkg/api/store_apps.go index 43bc04530..05b6856b2 100644 --- a/server/pkg/api/store_apps.go +++ b/server/pkg/api/store_apps.go @@ -3,7 +3,6 @@ package api import ( "context" - "github.com/kube-tarian/kad/server/pkg/agent" "github.com/kube-tarian/kad/server/pkg/pb/agentpb" "github.com/kube-tarian/kad/server/pkg/pb/serverpb" "github.com/kube-tarian/kad/server/pkg/types" @@ -275,8 +274,7 @@ func (s *Server) DeployStoreApp(ctx context.Context, request *serverpb.DeploySto } - agnetHandler := agent.NewAgentHandler(s.log, s.serverStore, s.oryClient) - agent, err := agnetHandler.GetAgent(clusterId) + agent, err := s.agentHandeler.GetAgent(clusterId) if err != nil { s.log.Errorf("failed to initialize agent, %v", err) return &serverpb.DeployStoreAppResponse{ diff --git a/server/pkg/config/config.go b/server/pkg/config/config.go index 9a1dcbeaf..1b535d05e 100644 --- a/server/pkg/config/config.go +++ b/server/pkg/config/config.go @@ -9,6 +9,7 @@ type ServiceConfig struct { ServerPort int `envconfig:"SERVER_PORT" default:"8080"` ServerGRPCHost string `envconfig:"SERVER_GRPC_HOST" default:"0.0.0.0"` ServerGRPCPort int `envconfig:"SERVER_GRPC_PORT" default:"8081"` + ServiceName string `envconfig:"SERVICE_NAME" default:"capten-server"` Database string `envconfig:"DATABASE" default:"astra"` AuthEnabled bool `envconfig:"AUTH_ENABLED" default:"false"` RegisterLaunchAppsConifg bool `envconfig:"REGISTER_LAUNCH_APPS_CONFIG" default:"false"` diff --git a/server/pkg/credential/client.go b/server/pkg/credential/client.go index 9398aef52..c48cd9e4c 100644 --- a/server/pkg/credential/client.go +++ b/server/pkg/credential/client.go @@ -8,11 +8,11 @@ import ( ) const ( - clusterCertEntity = "client-cert" - oauthIdentifier = "service-reg-identifier" - oauthEntityName = "service-reg" - iamClientKey = "IAM_CLIENTID" - iamSecretKey = "IAM_SECRET" + clusterCertEntity = "client-cert" + oauthIdentifier = "service-reg-identifier" + serviceClientOAuthEntityName = "service-client-oauth" + oauthClientIdKey = "CLIENT-ID" + oauthClientSecretKey = "CLIENT-SECRET" ) func GetServiceUserCredential(ctx context.Context, svcEntity, userName string) (cred credentials.ServiceCredential, err error) { @@ -88,49 +88,47 @@ func DeleteClusterCerts(ctx context.Context, clusterID string) (err error) { return } -func PutIamOauthCredential(ctx context.Context, clientid, secret string) error { - if clientid == "" || secret == "" { - return errors.New("either clientid or secret is missing, both are required") - } - +func StoreServiceOauthCredential(ctx context.Context, serviceName, clientId, clientSecret string) error { credWriter, err := credentials.NewCredentialAdmin(ctx) if err != nil { return errors.WithMessage(err, "error in initializing credential admin") } - credData := make(map[string]string) - credData[iamClientKey] = clientid - credData[iamSecretKey] = secret + cred := map[string]string{ + oauthClientIdKey: clientId, + oauthClientSecretKey: clientSecret, + } - err = credWriter.PutCredential(ctx, "generic", oauthEntityName, oauthIdentifier, credData) + err = credWriter.PutCredential(ctx, credentials.GenericCredentialType, + serviceClientOAuthEntityName, serviceName, cred) if err != nil { - return errors.WithMessage(err, "error while putting IAM credentials into the vault") + return errors.WithMessagef(err, "error while storing service oauth credential %s/%s into the vault", + serviceClientOAuthEntityName, serviceName) } - return nil } -func GetOauthCredentialFromVault(ctx context.Context, ClientKey, SecretKey string) (clientid, secret string, err error) { +func GetServiceOauthCredential(ctx context.Context, serviceName string) (clientId, clientSecret string, err error) { credReader, err := credentials.NewCredentialReader(ctx) if err != nil { - return "", "", errors.WithMessage(err, "error in initializing credential reader") + err = errors.WithMessage(err, "error in initializing credential reader") + return } - cred, err := credReader.GetCredential(ctx, "generic", oauthEntityName, oauthIdentifier) + cred, err := credReader.GetCredential(ctx, credentials.GenericCredentialType, + serviceClientOAuthEntityName, serviceName) if err != nil { - return "", "", errors.WithMessagef(err, "error in reading credential for %s/%s", oauthEntityName, oauthIdentifier) - } - - clientid, ok1 := cred[ClientKey] - secret, ok2 := cred[SecretKey] - - if !ok1 { - return "", "", errors.Errorf("credential with %s key is not present in generic credential type", iamClientKey) + err = errors.WithMessagef(err, "error while reading service oauth credential %s/%s from the vault", + serviceClientOAuthEntityName, serviceName) + return } - if !ok2 { - return "", "", errors.Errorf("credential with %s key is not present in generic credential type", iamSecretKey) + clientId = cred[oauthClientIdKey] + clientSecret = cred[oauthClientSecretKey] + if len(clientId) == 0 || len(clientSecret) == 0 { + err = errors.WithMessagef(err, "invalid service oauth credential %s/%s in the vault", + serviceClientOAuthEntityName, serviceName) + return } - - return clientid, secret, nil + return } diff --git a/server/pkg/handler/handle_agent_test.go b/server/pkg/handler/handle_agent_test.go index 4ad2c33ef..0eda53e31 100644 --- a/server/pkg/handler/handle_agent_test.go +++ b/server/pkg/handler/handle_agent_test.go @@ -15,6 +15,7 @@ import ( "github.com/intelops/go-common/logging" "github.com/kube-tarian/kad/server/api" "github.com/kube-tarian/kad/server/pkg/agent" + "github.com/kube-tarian/kad/server/pkg/config" "github.com/kube-tarian/kad/server/pkg/pb/agentpb" "github.com/stretchr/testify/require" ) @@ -36,7 +37,7 @@ func TestAPIHandler_Close(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.agentHandler.RemoveAgent(tt.args.customerId) }) @@ -56,7 +57,7 @@ func TestAPIHandler_CloseAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.Close() }) @@ -81,7 +82,7 @@ func TestAPIHandler_ConnectClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } if _, err := a.agentHandler.GetAgent(tt.args.customerId); (err != nil) != tt.wantErr { t.Errorf("ConnectClient() error = %v, wantErr %v", err, tt.wantErr) @@ -107,7 +108,7 @@ func TestAPIHandler_DeleteAgentClimondeploy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.DeleteAgentClimondeploy(tt.args.c) }) @@ -131,7 +132,7 @@ func TestAPIHandler_DeleteAgentCluster(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.DeleteAgentCluster(tt.args.c) }) @@ -155,7 +156,7 @@ func TestAPIHandler_DeleteAgentDeploy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.DeleteAgentDeploy(tt.args.c) }) @@ -179,7 +180,7 @@ func TestAPIHandler_DeleteAgentProject(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.DeleteAgentProject(tt.args.c) }) @@ -203,7 +204,7 @@ func TestAPIHandler_DeleteAgentRepository(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.DeleteAgentRepository(tt.args.c) }) @@ -227,7 +228,7 @@ func TestAPIHandler_GetAgentEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.GetAgentEndpoint(tt.args.c) }) @@ -251,7 +252,7 @@ func TestAPIHandler_GetApiDocs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.GetApiDocs(tt.args.c) }) @@ -276,7 +277,7 @@ func TestAPIHandler_GetClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } if got, err := a.agentHandler.GetAgent(tt.args.customerId); err != nil && !reflect.DeepEqual(got, tt.want) { t.Errorf("GetClient() = %v, want %v", got, tt.want) @@ -302,7 +303,7 @@ func TestAPIHandler_GetStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.GetStatus(tt.args.c) }) @@ -389,7 +390,7 @@ func TestAPIHandler_PostAgentApps(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentApps(tt.args.c) }) @@ -413,7 +414,7 @@ func TestAPIHandler_PostAgentClimondeploy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentClimondeploy(tt.args.c) }) @@ -437,7 +438,7 @@ func TestAPIHandler_PostAgentCluster(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentCluster(tt.args.c) }) @@ -461,7 +462,7 @@ func TestAPIHandler_PostAgentDeploy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentDeploy(tt.args.c) }) @@ -485,7 +486,7 @@ func TestAPIHandler_PostAgentEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentEndpoint(tt.args.c) }) @@ -509,7 +510,7 @@ func TestAPIHandler_PostAgentProject(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentProject(tt.args.c) }) @@ -533,7 +534,7 @@ func TestAPIHandler_PostAgentRepository(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentRepository(tt.args.c) }) @@ -557,7 +558,7 @@ func TestAPIHandler_PostAgentSecret(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PostAgentSecret(tt.args.c) }) @@ -581,7 +582,7 @@ func TestAPIHandler_PutAgentClimondeploy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PutAgentClimondeploy(tt.args.c) }) @@ -605,7 +606,7 @@ func TestAPIHandler_PutAgentDeploy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PutAgentDeploy(tt.args.c) }) @@ -629,7 +630,7 @@ func TestAPIHandler_PutAgentEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PutAgentEndpoint(tt.args.c) }) @@ -653,7 +654,7 @@ func TestAPIHandler_PutAgentProject(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PutAgentProject(tt.args.c) }) @@ -677,7 +678,7 @@ func TestAPIHandler_PutAgentRepository(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.PutAgentRepository(tt.args.c) }) @@ -704,7 +705,7 @@ func TestAPIHandler_getFileContent(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } got, err := a.getFileContent(tt.args.c, tt.args.fileInfo) if (err != nil) != tt.wantErr { @@ -737,7 +738,7 @@ func TestAPIHandler_sendResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.sendResponse(tt.args.c, tt.args.msg, tt.args.err) }) @@ -763,7 +764,7 @@ func TestAPIHandler_setFailedResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &APIHandler{ - agentHandler: agent.NewAgentHandler(logging.NewLogger(), nil, nil), + agentHandler: agent.NewAgentHandler(logging.NewLogger(), config.ServiceConfig{}, nil, nil), } a.setFailedResponse(tt.args.c, tt.args.msg, tt.args.err) }) diff --git a/server/pkg/handler/handler.go b/server/pkg/handler/handler.go index 0da6f4997..381735cc5 100644 --- a/server/pkg/handler/handler.go +++ b/server/pkg/handler/handler.go @@ -9,6 +9,7 @@ import ( "github.com/intelops/go-common/logging" "github.com/kube-tarian/kad/server/api" "github.com/kube-tarian/kad/server/pkg/agent" + "github.com/kube-tarian/kad/server/pkg/config" oryclient "github.com/kube-tarian/kad/server/pkg/ory-client" "github.com/kube-tarian/kad/server/pkg/store" ) @@ -26,7 +27,7 @@ var ( func NewAPIHandler(log logging.Logger, serverStore store.ServerStore, oryClient oryclient.OryClient) (*APIHandler, error) { return &APIHandler{ log: log, - agentHandler: agent.NewAgentHandler(log, serverStore, oryClient), + agentHandler: agent.NewAgentHandler(log, config.ServiceConfig{}, serverStore, oryClient), }, nil } diff --git a/server/pkg/iam-client/client.go b/server/pkg/iam-client/client.go index d2835dff3..16321658f 100644 --- a/server/pkg/iam-client/client.go +++ b/server/pkg/iam-client/client.go @@ -14,9 +14,9 @@ import ( "google.golang.org/grpc/metadata" ) -type SecretManager interface { +type IAMRegister interface { RegisterAppClientSecrets(ctx context.Context, clientName, redirectURL string) (string, string, error) - GetURL() string + GetOAuthURL() string } type Client struct { @@ -33,13 +33,13 @@ func NewClient(log logging.Logger, ory oryclient.OryClient, cfg Config) (*Client }, nil } -func (c *Client) RegisterWithIam() error { +func (c *Client) RegisterService() error { conn, err := grpc.Dial(c.cfg.IAMURL, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return err } - iamclient := iampb.NewOauthServiceClient(conn) + iamclient := iampb.NewOauthServiceClient(conn) oauthClientReq := &iampb.CreateClientCredentialsClientRequest{ ClientName: c.cfg.ServiceName, } @@ -47,73 +47,62 @@ func (c *Client) RegisterWithIam() error { if err != nil { return err } - err = credential.PutIamOauthCredential(context.Background(), res.ClientId, res.ClientSecret) + + err = credential.StoreServiceOauthCredential(context.Background(), c.cfg.ServiceName, res.ClientId, res.ClientSecret) if err != nil { return err } return nil } -// at the line cm.WithIamYamlPath("provide the yaml location here"), -// the roles and actions should be added to ConfigMap -// the the location should be provided func (c *Client) RegisterRolesActions() error { grpcOpts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), } - // Create an instance of IamConn with desired options - // the order of calling the options should be same as given in example + iamConn := cm.NewIamConn( cm.WithGrpcDialOption(grpcOpts...), cm.WithIamAddress(c.cfg.IAMURL), - // TODO: here need to add the roles and actions yaml location cm.WithIamYamlPath("provide the yaml location here"), ) + ctx := context.Background() - tkn, err := c.oryClient.GetCaptenServiceRegOauthToken() + oauthCred, err := c.oryClient.GetServiceOauthCredential(ctx, c.cfg.ServiceName) if err != nil { - err = errors.WithMessage(err, "error getting capten service reg oauth token") - return err - } - if tkn == nil { - return errors.New("capten service reg oauth token is nil") + return errors.WithMessage(err, "error while getting service oauth token") } + md := metadata.Pairs( - "oauth_token", *tkn, + "oauth_token", oauthCred.AccessToken, "ory_url", c.oryClient.GetURL(), "ory_pat", c.oryClient.GetPAT(), ) + newCtx := metadata.NewOutgoingContext(ctx, md) - // Update action roles err = iamConn.UpdateActionRoles(newCtx) if err != nil { - c.log.Errorf("Failed to update action roles: %v", err) - return err + return errors.WithMessage(err, "Failed to update action roles") } return nil } -func (c *Client) GetURL() string { +func (c *Client) GetOAuthURL() string { return c.oryClient.GetURL() } func (c *Client) RegisterAppClientSecrets(ctx context.Context, clientName, redirectURL string) (string, string, error) { - conn, err := grpc.Dial(c.cfg.IAMURL, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return "", "", err } - defer conn.Close() iamclient := iampb.NewOauthServiceClient(conn) - res, err := iamclient.CreateOauthClient(context.Background(), &iampb.OauthClientRequest{ ClientName: clientName, RedirectUris: []string{redirectURL}, }) if err != nil { return "", "", err } - return res.ClientId, res.ClientSecret, nil } diff --git a/server/pkg/iam-client/register_service.go b/server/pkg/iam-client/register_service.go index 11add8fcb..d54796aad 100644 --- a/server/pkg/iam-client/register_service.go +++ b/server/pkg/iam-client/register_service.go @@ -38,17 +38,12 @@ func RegisterService(log logging.Logger) error { return errors.WithMessage(err, "OryClient initialization failed") } - IC, err := NewClient(log, oryclient, cfg) + iamClient, err := NewClient(log, oryclient, cfg) if err != nil { return errors.WithMessage(err, "Error occured while created IAM client") } - err = IC.RegisterWithIam() - if err != nil { - return errors.WithMessage(err, "Registering capten server as oauth client failed") - } - - err = IC.RegisterRolesActions() + err = iamClient.RegisterRolesActions() if err != nil { return errors.WithMessage(err, "Registering Roles and Actions in IAM failed") } diff --git a/server/pkg/ory-client/client.go b/server/pkg/ory-client/client.go index e5e5f6abb..625b47855 100644 --- a/server/pkg/ory-client/client.go +++ b/server/pkg/ory-client/client.go @@ -4,40 +4,26 @@ import ( "context" "strings" - "github.com/intelops/go-common/credentials" "github.com/intelops/go-common/logging" "github.com/kelseyhightower/envconfig" "github.com/kube-tarian/kad/server/pkg/credential" ory "github.com/ory/client-go" "github.com/pkg/errors" "golang.org/x/oauth2/clientcredentials" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) -const ( - iamClientKey = "IAM_CLIENTID" - iamSecretKey = "IAM_SECRET" -) - -// Config represents the configuration settings required for -// fetching ory entities from the vault -// also for integration with ORY and create a OryApiClient. type Config struct { OryEntityName string `envconfig:"ORY_ENTITY_NAME" required:"true"` CredentialIdentifier string `envconfig:"ORY_CRED_IDENTIFIER" required:"true"` } -// TokenConfig represents the configuration settings required for -// fetching the client id and secret from the vault -// also for fetching oauth token from IAM -type TokenConfig struct { - CaptenServiceEntity string `envconfig:"CAPTEN_SERVER_ENTITY" required:"true"` - CaptenServiceIdenity string `envconfig:"CAPTEN_SERVER_IDENTIFIER" required:"true"` - CaptenClientKey string `envconfig:"CAPTEN_CLIENT_KEY" required:"true"` - CaptenClientSecret string `envconfig:"CAPTEN_CLIENT_SECRET" required:"true"` +type OauthAccessCredential struct { + OryPAT string + OryURL string + AccessToken string } type Client struct { @@ -50,13 +36,11 @@ type Client struct { type OryClient interface { GetSessionTokenFromContext(ctx context.Context) (string, error) Authorize(ctx context.Context, accessToken string) (context.Context, error) - UnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error - GetCaptenServiceRegOauthToken() (*string, error) + GetServiceOauthCredential(ctx context.Context, serviceName string) (*OauthAccessCredential, error) GetURL() string GetPAT() string } -// NewOryClient returns a OryClient interface func NewOryClient(log logging.Logger) (OryClient, error) { cfg := &Config{} if err := envconfig.Process("", cfg); err != nil { @@ -70,7 +54,7 @@ func NewOryClient(log logging.Logger) (OryClient, error) { } oryPAT := serviceCredential.AdditionalData["ORY_PAT"] oryURL := serviceCredential.AdditionalData["ORY_URL"] - conn := NewOrySdk(log, oryURL) + conn := newOryAPIClient(log, oryURL) return &Client{ oryPAT: oryPAT, conn: conn, @@ -79,6 +63,14 @@ func NewOryClient(log logging.Logger) (OryClient, error) { }, nil } +func newOryAPIClient(log logging.Logger, oryURL string) *ory.APIClient { + config := ory.NewConfiguration() + config.Servers = ory.ServerConfigurations{{ + URL: oryURL, + }} + return ory.NewAPIClient(config) +} + func (c *Client) GetURL() string { return c.oryURL } @@ -87,38 +79,22 @@ func (c *Client) GetPAT() string { return c.oryPAT } -func getTokenEnv() (*TokenConfig, error) { - cfg := &TokenConfig{} - if err := envconfig.Process("", cfg); err != nil { - return nil, err - } - return cfg, nil -} - -// NewOrySdk creates a oryAPIClient using the oryURL -// and returns it -func NewOrySdk(log logging.Logger, oryURL string) *ory.APIClient { - log.Info("creating a ory client") - config := ory.NewConfiguration() - config.Servers = ory.ServerConfigurations{{ - URL: oryURL, - }} - - return ory.NewAPIClient(config) +func (c *Client) GetOryTokenUrl() string { + tokenUrl := c.oryURL + "/oauth2/token" + return tokenUrl } -// GetSessionTokenFromContext fetches the session token from the context -// and returns the token and nil for the error. -// But if any error occurs while fetching the token it returns an empty string and an error. func (c *Client) GetSessionTokenFromContext(ctx context.Context) (string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return "", status.Error(codes.Unauthenticated, "Failed to get metadata from context") } + bearerToken := md.Get("authorization") if len(bearerToken) == 0 || len(strings.Split(bearerToken[0], " ")) != 2 { return "", status.Error(codes.Unauthenticated, "No access token provided") } + accessToken := bearerToken[0] if len(accessToken) < 8 || accessToken[:7] != "Bearer " { return "", status.Error(codes.Unauthenticated, "Invalid access token") @@ -126,111 +102,41 @@ func (c *Client) GetSessionTokenFromContext(ctx context.Context) (string, error) return accessToken[7:], nil } -// Authorize checks whether the accesstoken is valid or Invalid using the ory.APIClient -// It checks token is active or not active -// If token is active its a valid token -// If token is not active its a invalid token func (c *Client) Authorize(ctx context.Context, accessToken string) (context.Context, error) { ctx = context.WithValue(ctx, ory.ContextAccessToken, c.oryPAT) sessionInfo, _, err := c.conn.IdentityApi.GetSession(ctx, accessToken).Execute() if err != nil { - c.log.Errorf("Error occured while getting session info for session id - "+accessToken+"+\nError - %v", err.Error()) - return ctx, status.Errorf(codes.Unauthenticated, "Failed to introspect session id - %v", err) + return ctx, status.Errorf(codes.Unauthenticated, "Failed to introspect session id, %v", err) } + c.log.Infof("session id: %v", sessionInfo.Id) if !sessionInfo.GetActive() { - c.log.Errorf("Error occured while getting session info for session id - "+accessToken+"+\nError - %v", err.Error()) return ctx, status.Error(codes.Unauthenticated, "session id is not active") } return ctx, nil } -func (c *Client) GetOryTokenUrl() string { - tokenUrl := c.oryURL + "/oauth2/token" - return tokenUrl -} -func (c *Client) GetCaptenOauthToken(ctx context.Context, ClientKey, SecretKey string) (context.Context, error) { - clientid, secret, err := credential.GetOauthCredentialFromVault(ctx, ClientKey, SecretKey) + +func (c *Client) GetServiceOauthCredential(ctx context.Context, serviceName string) (*OauthAccessCredential, error) { + clientId, clientSecret, err := credential.GetServiceOauthCredential(ctx, serviceName) if err != nil { - c.log.Errorf("error while getting clientid and secret from vault: %v", err.Error()) - return ctx, err + return nil, err } conf := &clientcredentials.Config{ - ClientID: clientid, - ClientSecret: secret, + ClientID: clientId, + ClientSecret: clientSecret, Scopes: []string{"openid email offline"}, TokenURL: c.GetOryTokenUrl(), } - at, err := conf.Token(ctx) - if err != nil { - c.log.Errorf("error while fetching oauth token from oryapiclient ERROR: %v", err.Error()) - return ctx, err - } - md := metadata.Pairs("oauth_token", at.AccessToken, - "ory_url", c.oryURL, - "ory_pat", c.oryPAT, - ) - newCtx := metadata.NewOutgoingContext(ctx, md) - return newCtx, nil -} - -func (c *Client) UnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - newCtx, err := c.GetCaptenOauthToken(ctx, iamClientKey, iamSecretKey) - if err != nil { - return err - } - return invoker(newCtx, method, req, reply, cc, opts...) -} - -func (c *Client) GetCaptenServiceRegOauthToken() (*string, error) { - cfg, err := getTokenEnv() - if err != nil { - return nil, err - } - data, err := GetFromVault(context.Background(), cfg.CaptenServiceEntity, cfg.CaptenServiceIdenity) - if err != nil { - return nil, err - } - - clientid, ok := data[cfg.CaptenClientKey] - if !ok { - return nil, errors.New("capten service client id not found in vault data") - } - - clientSecret, ok := data[cfg.CaptenClientSecret] - if !ok { - return nil, errors.New("capten service client secret not found in vault data") - } - - ctxWithToken, err := c.GetCaptenOauthToken(context.Background(), clientid, clientSecret) + oauthToken, err := conf.Token(ctx) if err != nil { - return nil, err + return nil, errors.WithMessagef(err, "error while fetching oauth token") } - // Extract the token from the context - md, ok := metadata.FromOutgoingContext(ctxWithToken) - if !ok { - return nil, errors.New("failed to extract metadata from context") - } - - token, ok := md["oauth_token"] - if !ok || len(token) == 0 { - return nil, errors.New("oauth_token not found in context metadata") - } - - return &token[0], nil -} - -func GetFromVault(ctx context.Context, en, iden string) (map[string]string, error) { - credReader, err := credentials.NewCredentialReader(ctx) - if err != nil { - err = errors.WithMessage(err, "error in initializing credential reader") - return nil, err - } - cred, err := credReader.GetCredential(ctx, "generic", en, iden) - if err != nil { - return nil, err - } - return cred, nil + return &OauthAccessCredential{ + OryPAT: c.oryPAT, + OryURL: c.oryURL, + AccessToken: oauthToken.AccessToken, + }, nil }