Skip to content

Commit

Permalink
Merge pull request #417 from authzed/fix-context-override
Browse files Browse the repository at this point in the history
fixes context override
  • Loading branch information
vroldanbet authored Sep 18, 2024
2 parents c03fcf6 + f5626a8 commit af2bed5
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 78 deletions.
116 changes: 100 additions & 16 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
grpc "google.golang.org/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

zgrpcutil "github.com/authzed/zed/internal/grpcutil"
Expand All @@ -28,20 +28,17 @@ type Client interface {
}

// NewClient defines an (overridable) means of creating a new client.
var NewClient = newGRPCClient
var (
NewClient = newClientForCurrentContext
NewClientForContext = newClientForContext
)

func newGRPCClient(cmd *cobra.Command) (Client, error) {
func newClientForCurrentContext(cmd *cobra.Command) (Client, error) {
configStore, secretStore := DefaultStorage()
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
Expand All @@ -56,28 +53,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) {
return client, err
}

func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
currentToken, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}

token, err := GetTokenWithCLIOverride(cmd, currentToken)
if err != nil {
return nil, err
}

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}

return authzed.NewClient(token.Endpoint, dialOpts...)
}

// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args
func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) {
token, err := storage.CurrentToken(
configStore,
secretStore,
)
if err != nil {
return storage.Token{}, err
}

return GetTokenWithCLIOverride(cmd, token)
}

// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command
// flags
func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) {
overrideToken, err := tokenFromCli(cmd)
if err != nil {
return storage.Token{}, err
}

result, err := storage.TokenWithOverride(
overrideToken,
token,
)
if err != nil {
return storage.Token{}, err
}

log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send()
return result, nil
}

func tokenFromCli(cmd *cobra.Command) (storage.Token, error) {
certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
var certBytes []byte
var err error
if certPath != "" {
certBytes, err = os.ReadFile(certPath)
if err != nil {
return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err)
}
}

explicitInsecure := cmd.Flags().Changed("insecure")
var notSecure *bool
if explicitInsecure {
i := cobrautil.MustGetBool(cmd, "insecure")
notSecure = &i
}

explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca")
var notVerifyCA *bool
if explicitNoVerifyCA {
nvc := cobrautil.MustGetBool(cmd, "no-verify-ca")
notVerifyCA = &nvc
}
overrideToken := storage.Token{
APIToken: cobrautil.MustGetString(cmd, "token"),
Endpoint: cobrautil.MustGetString(cmd, "endpoint"),
Insecure: notSecure,
NoVerifyCA: notVerifyCA,
CACert: certBytes,
}
return overrideToken, nil
}

// DefaultStorage returns the default configured config store and secret store.
func DefaultStorage() (storage.ConfigStore, storage.SecretStore) {
var home string
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
home = filepath.Join(xdg, "zed")
} else {
homedir, _ := homedir.Dir()
home = filepath.Join(homedir, ".zed")
hmdir, _ := homedir.Dir()
home = filepath.Join(hmdir, ".zed")
}
return &storage.JSONConfigStore{ConfigPath: home},
&storage.KeychainSecretStore{ConfigPath: home}
}

func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) {
func certOption(token storage.Token) (opt grpc.DialOption, err error) {
verification := grpcutil.VerifyCA
if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() {
if token.HasNoVerifyCA() {
verification = grpcutil.SkipVerifyCA
}

if certBytes, ok := token.Certificate(); ok {
return grpcutil.WithCustomCertBytes(verification, certBytes)
}

return grpcutil.WithSystemCerts(verification)
}

Expand All @@ -96,12 +180,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers),
}

if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) {
if token.IsInsecure() {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken))
} else {
opts = append(opts, grpcutil.WithBearerToken(token.APIToken))
certOpt, err := certOption(cmd, token)
certOpt, err := certOption(token)
if err != nil {
return nil, fmt.Errorf("failed to configure TLS cert: %w", err)
}
Expand Down
62 changes: 62 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package client_test

import (
"os"
"testing"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/storage"
zedtesting "github.com/authzed/zed/internal/testing"

"github.com/stretchr/testify/require"
)

func TestGetTokenWithCLIOverride(t *testing.T) {
testCert, err := os.CreateTemp("", "")
require.NoError(t, err)
_, err = testCert.Write([]byte("hi"))
require.NoError(t, err)
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true},
)

bTrue := true
bFalse := false

// cli args take precedence when defined
to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t1", to.APIToken)
require.Equal(t, "e1", to.Endpoint)
require.Equal(t, []byte("hi"), to.CACert)
require.Equal(t, &bTrue, to.Insecure)
require.Equal(t, &bTrue, to.NoVerifyCA)

// storage token takes precedence when defined
cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false},
)
to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{
APIToken: "t2",
Endpoint: "e2",
CACert: []byte("bye"),
Insecure: &bFalse,
NoVerifyCA: &bFalse,
})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t2", to.APIToken)
require.Equal(t, "e2", to.Endpoint)
require.Equal(t, []byte("bye"), to.CACert)
require.Equal(t, &bFalse, to.Insecure)
require.Equal(t, &bFalse, to.NoVerifyCA)
}
22 changes: 3 additions & 19 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
"github.com/authzed/spicedb/pkg/schemadsl/generator"
"github.com/authzed/spicedb/pkg/schemadsl/input"
Expand All @@ -23,7 +22,6 @@ import (
"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
Expand Down Expand Up @@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{
RunE: schemaCopyCmdFunc,
}

// TODO(jschorr): support this in the client package
func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
token, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := client.DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}
return authzed.NewClient(token.Endpoint, dialOpts...)
}

func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
_, secretStore := client.DefaultStorage()
srcClient, err := clientForContext(cmd, args[0], secretStore)
srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
if err != nil {
return err
}
destClient, err := clientForContext(cmd, args[1], secretStore)

destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
if err != nil {
return err
}
Expand Down
22 changes: 2 additions & 20 deletions internal/cmd/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"github.com/gookit/color"
"github.com/jzelinskie/cobrautil/v2"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func versionCmdFunc(cmd *cobra.Command, _ []string) error {
Expand All @@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {

includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
hasContext := false
configStore, secretStore := client.DefaultStorage()
if includeRemoteVersion {
_, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
configStore, secretStore := client.DefaultStorage()
_, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
hasContext = err == nil
}

Expand All @@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))

if hasContext && includeRemoteVersion {
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
if err != nil {
return err
}
log.Trace().Interface("token", token).Send()

client, err := client.NewClient(cmd)
if err != nil {
return err
Expand Down
Loading

0 comments on commit af2bed5

Please sign in to comment.