Skip to content

Commit

Permalink
refactor service registration actions and ory auth
Browse files Browse the repository at this point in the history
  • Loading branch information
vramk23 committed Aug 21, 2023
1 parent aa7f0e9 commit b1b3f6d
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 241 deletions.
9 changes: 7 additions & 2 deletions server/pkg/agent/agent_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/intelops/go-common/logging"
"github.com/kube-tarian/kad/server/pkg/config"
"github.com/kube-tarian/kad/server/pkg/credential"
oryclient "github.com/kube-tarian/kad/server/pkg/ory-client"
"github.com/kube-tarian/kad/server/pkg/store"
Expand All @@ -14,21 +15,25 @@ import (

type AgentHandler struct {
log logging.Logger
cfg config.ServiceConfig
agentMutex sync.RWMutex
agents map[string]*Agent
serverStore store.ServerStore
oryClient oryclient.OryClient
}

func NewAgentHandler(log logging.Logger, serverStore store.ServerStore, oryClient oryclient.OryClient) *AgentHandler {
return &AgentHandler{log: log, serverStore: serverStore, agents: map[string]*Agent{}, oryClient: oryClient}
func NewAgentHandler(log logging.Logger, cfg config.ServiceConfig,
serverStore store.ServerStore, oryClient oryclient.OryClient) *AgentHandler {
return &AgentHandler{log: log, cfg: cfg, serverStore: serverStore, agents: map[string]*Agent{}, oryClient: oryClient}
}

func (s *AgentHandler) AddAgent(clusterID string, agentCfg *Config) error {
if _, ok := s.agents[clusterID]; ok {
return nil
}

agentCfg.ServicName = s.cfg.ServiceName
agentCfg.AuthEnabled = s.cfg.AuthEnabled
agent, err := NewAgent(s.log, agentCfg, s.oryClient)
if err != nil {
return err
Expand Down
39 changes: 34 additions & 5 deletions server/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)

type Config struct {
Expand All @@ -24,6 +25,8 @@ type Config struct {
CaCert string
Cert string
Key string
ServicName string
AuthEnabled bool
}

type Agent struct {
Expand All @@ -33,9 +36,9 @@ type Agent struct {
log logging.Logger
}

func NewAgent(log logging.Logger, cfg *Config, oryclient oryclient.OryClient) (*Agent, error) {
func NewAgent(log logging.Logger, cfg *Config, oryClient oryclient.OryClient) (*Agent, error) {
log.Infof("connecting to agent %s", cfg.Address)
conn, err := getConnection(cfg, oryclient)
conn, err := getConnection(cfg, oryClient)
if err != nil {
return nil, errors.WithMessage(err, "failed to connect to agent")
}
Expand All @@ -57,14 +60,16 @@ func NewAgent(log logging.Logger, cfg *Config, oryclient oryclient.OryClient) (*
}, nil
}

func getConnection(cfg *Config, client oryclient.OryClient) (*grpc.ClientConn, error) {
func getConnection(cfg *Config, oryClient oryclient.OryClient) (*grpc.ClientConn, error) {
address, port, tls, err := parseAgentConnectionConfig(cfg.Address)
if err != nil {
return nil, err
}

dialOptions := []grpc.DialOption{
grpc.WithUnaryInterceptor(client.UnaryInterceptor),
dialOptions := []grpc.DialOption{}

if cfg.AuthEnabled {
dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(authInterceptor(oryClient, cfg.ServicName)))
}

if !tls {
Expand All @@ -88,6 +93,30 @@ func (a *Agent) Close() {
a.connection.Close()
}

func authInterceptor(oryClient oryclient.OryClient, serviceName string) grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
req, reply interface{},
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
oauthCred, err := oryClient.GetServiceOauthCredential(ctx, serviceName)
if err != nil {
return err
}

md := metadata.Pairs(
"oauth_token", oauthCred.AccessToken,
"ory_url", oauthCred.OryURL,
"ory_pat", oauthCred.OryPAT,
)
newCtx := metadata.NewOutgoingContext(ctx, md)
return invoker(newCtx, method, req, reply, cc, opts...)
}
}

func loadTLSCredentials(config *Config) (credentials.TransportCredentials, error) {
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM([]byte(config.CaCert)) {
Expand Down
2 changes: 1 addition & 1 deletion server/pkg/api/cluster_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (s *Server) configureSSOForClusterApps(ctx context.Context, clusterID strin
ReleaseName: app.ReleaseName,
ClientId: clientID,
ClientSecret: clientSecret,
OAuthBaseURL: s.iam.GetURL(),
OAuthBaseURL: s.iam.GetOAuthURL(),
})

if err != nil || ssoResp.Status != 0 {
Expand Down
7 changes: 1 addition & 6 deletions server/pkg/api/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@ import (
"google.golang.org/grpc"
)

// UnaryInterceptor is a gRPC server-side interceptor that handles authentication for unary RPCs.
// It first attempts to retrieve an access token from the context using the ORY client interface.
// If the token retrieval is successful, it then tries to authorize the token using the ORY client interface.
// If either step fails, the interceptor logs the error and returns it, halting the RPC.
// If both steps are successful, the interceptor invokes the provided handler with the updated context and request.
func (s Server) AuthInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
func (s *Server) AuthInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
accessToken, err := s.oryClient.GetSessionTokenFromContext(ctx)
if err != nil {
s.log.Debugf("error occured while fetching the token from the context. Error - %s", err.Error())
Expand Down
6 changes: 3 additions & 3 deletions server/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ type Server struct {
agentHandeler *agent.AgentHandler
log logging.Logger
oryClient oryclient.OryClient
iam iamclient.SecretManager
iam iamclient.IAMRegister
cfg config.ServiceConfig
}

func NewServer(log logging.Logger, cfg config.ServiceConfig, serverStore store.ServerStore,
oryClient oryclient.OryClient, iam iamclient.SecretManager) (*Server, error) {
oryClient oryclient.OryClient, iam iamclient.IAMRegister) (*Server, error) {
return &Server{
serverStore: serverStore,
agentHandeler: agent.NewAgentHandler(log, serverStore, oryClient),
agentHandeler: agent.NewAgentHandler(log, cfg, serverStore, oryClient),
log: log,
oryClient: oryClient,
iam: iam,
Expand Down
4 changes: 1 addition & 3 deletions server/pkg/api/store_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package api
import (
"context"

"github.com/kube-tarian/kad/server/pkg/agent"
"github.com/kube-tarian/kad/server/pkg/pb/agentpb"
"github.com/kube-tarian/kad/server/pkg/pb/serverpb"
"github.com/kube-tarian/kad/server/pkg/types"
Expand Down Expand Up @@ -275,8 +274,7 @@ func (s *Server) DeployStoreApp(ctx context.Context, request *serverpb.DeploySto

}

agnetHandler := agent.NewAgentHandler(s.log, s.serverStore, s.oryClient)
agent, err := agnetHandler.GetAgent(clusterId)
agent, err := s.agentHandeler.GetAgent(clusterId)
if err != nil {
s.log.Errorf("failed to initialize agent, %v", err)
return &serverpb.DeployStoreAppResponse{
Expand Down
1 change: 1 addition & 0 deletions server/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type ServiceConfig struct {
ServerPort int `envconfig:"SERVER_PORT" default:"8080"`
ServerGRPCHost string `envconfig:"SERVER_GRPC_HOST" default:"0.0.0.0"`
ServerGRPCPort int `envconfig:"SERVER_GRPC_PORT" default:"8081"`
ServiceName string `envconfig:"SERVICE_NAME" default:"capten-server"`
Database string `envconfig:"DATABASE" default:"astra"`
AuthEnabled bool `envconfig:"AUTH_ENABLED" default:"false"`
RegisterLaunchAppsConifg bool `envconfig:"REGISTER_LAUNCH_APPS_CONFIG" default:"false"`
Expand Down
59 changes: 28 additions & 31 deletions server/pkg/credential/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
)

const (
clusterCertEntity = "client-cert"
oauthIdentifier = "service-reg-identifier"
oauthEntityName = "service-reg"
iamClientKey = "IAM_CLIENTID"
iamSecretKey = "IAM_SECRET"
clusterCertEntity = "client-cert"
serviceClientOAuthEntityName = "service-client-oauth"
oauthClientIdKey = "CLIENT-ID"
oauthClientSecretKey = "CLIENT-SECRET"
)

func GetServiceUserCredential(ctx context.Context, svcEntity, userName string) (cred credentials.ServiceCredential, err error) {
Expand Down Expand Up @@ -88,49 +87,47 @@ func DeleteClusterCerts(ctx context.Context, clusterID string) (err error) {
return
}

func PutIamOauthCredential(ctx context.Context, clientid, secret string) error {
if clientid == "" || secret == "" {
return errors.New("either clientid or secret is missing, both are required")
}

func StoreServiceOauthCredential(ctx context.Context, serviceName, clientId, clientSecret string) error {
credWriter, err := credentials.NewCredentialAdmin(ctx)
if err != nil {
return errors.WithMessage(err, "error in initializing credential admin")
}

credData := make(map[string]string)
credData[iamClientKey] = clientid
credData[iamSecretKey] = secret
cred := map[string]string{
oauthClientIdKey: clientId,
oauthClientSecretKey: clientSecret,
}

err = credWriter.PutCredential(ctx, "generic", oauthEntityName, oauthIdentifier, credData)
err = credWriter.PutCredential(ctx, credentials.GenericCredentialType,
serviceClientOAuthEntityName, serviceName, cred)
if err != nil {
return errors.WithMessage(err, "error while putting IAM credentials into the vault")
return errors.WithMessagef(err, "error while storing service oauth credential %s/%s into the vault",
serviceClientOAuthEntityName, serviceName)
}

return nil
}

func GetOauthCredentialFromVault(ctx context.Context, ClientKey, SecretKey string) (clientid, secret string, err error) {
func GetServiceOauthCredential(ctx context.Context, serviceName string) (clientId, clientSecret string, err error) {
credReader, err := credentials.NewCredentialReader(ctx)
if err != nil {
return "", "", errors.WithMessage(err, "error in initializing credential reader")
err = errors.WithMessage(err, "error in initializing credential reader")
return
}

cred, err := credReader.GetCredential(ctx, "generic", oauthEntityName, oauthIdentifier)
cred, err := credReader.GetCredential(ctx, credentials.GenericCredentialType,
serviceClientOAuthEntityName, serviceName)
if err != nil {
return "", "", errors.WithMessagef(err, "error in reading credential for %s/%s", oauthEntityName, oauthIdentifier)
}

clientid, ok1 := cred[ClientKey]
secret, ok2 := cred[SecretKey]

if !ok1 {
return "", "", errors.Errorf("credential with %s key is not present in generic credential type", iamClientKey)
err = errors.WithMessagef(err, "error while reading service oauth credential %s/%s from the vault",
serviceClientOAuthEntityName, serviceName)
return
}

if !ok2 {
return "", "", errors.Errorf("credential with %s key is not present in generic credential type", iamSecretKey)
clientId = cred[oauthClientIdKey]
clientSecret = cred[oauthClientSecretKey]
if len(clientId) == 0 || len(clientSecret) == 0 {
err = errors.WithMessagef(err, "invalid service oauth credential %s/%s in the vault",
serviceClientOAuthEntityName, serviceName)
return
}

return clientid, secret, nil
return
}
Loading

0 comments on commit b1b3f6d

Please sign in to comment.