diff --git a/internal/client/client.go b/internal/client/client.go index 7fd8a61..b236267 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -12,7 +12,7 @@ import ( "github.com/mitchellh/go-homedir" "github.com/rs/zerolog/log" "github.com/spf13/cobra" - grpc "google.golang.org/grpc" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" zgrpcutil "github.com/authzed/zed/internal/grpcutil" @@ -28,20 +28,17 @@ type Client interface { } // NewClient defines an (overridable) means of creating a new client. -var NewClient = newGRPCClient +var ( + NewClient = newClientForCurrentContext + NewClientForContext = newClientForContext +) -func newGRPCClient(cmd *cobra.Command) (Client, error) { +func newClientForCurrentContext(cmd *cobra.Command) (Client, error) { configStore, secretStore := DefaultStorage() - token, err := storage.DefaultToken( - cobrautil.MustGetString(cmd, "endpoint"), - cobrautil.MustGetString(cmd, "token"), - configStore, - secretStore, - ) + token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) if err != nil { return nil, err } - log.Trace().Interface("token", token).Send() dialOpts, err := DialOptsFromFlags(cmd, token) if err != nil { @@ -56,28 +53,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) { return client, err } +func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) { + currentToken, err := storage.GetToken(contextName, secretStore) + if err != nil { + return nil, err + } + + token, err := GetTokenWithCLIOverride(cmd, currentToken) + if err != nil { + return nil, err + } + + dialOpts, err := DialOptsFromFlags(cmd, token) + if err != nil { + return nil, err + } + + return authzed.NewClient(token.Endpoint, dialOpts...) +} + +// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args +func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) { + token, err := storage.CurrentToken( + configStore, + secretStore, + ) + if err != nil { + return storage.Token{}, err + } + + return GetTokenWithCLIOverride(cmd, token) +} + +// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command +// flags +func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) { + overrideToken, err := tokenFromCli(cmd) + if err != nil { + return storage.Token{}, err + } + + result, err := storage.TokenWithOverride( + overrideToken, + token, + ) + if err != nil { + return storage.Token{}, err + } + + log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send() + return result, nil +} + +func tokenFromCli(cmd *cobra.Command) (storage.Token, error) { + certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path") + var certBytes []byte + var err error + if certPath != "" { + certBytes, err = os.ReadFile(certPath) + if err != nil { + return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err) + } + } + + explicitInsecure := cmd.Flags().Changed("insecure") + var notSecure *bool + if explicitInsecure { + i := cobrautil.MustGetBool(cmd, "insecure") + notSecure = &i + } + + explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca") + var notVerifyCA *bool + if explicitNoVerifyCA { + nvc := cobrautil.MustGetBool(cmd, "no-verify-ca") + notVerifyCA = &nvc + } + overrideToken := storage.Token{ + APIToken: cobrautil.MustGetString(cmd, "token"), + Endpoint: cobrautil.MustGetString(cmd, "endpoint"), + Insecure: notSecure, + NoVerifyCA: notVerifyCA, + CACert: certBytes, + } + return overrideToken, nil +} + // DefaultStorage returns the default configured config store and secret store. func DefaultStorage() (storage.ConfigStore, storage.SecretStore) { var home string if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" { home = filepath.Join(xdg, "zed") } else { - homedir, _ := homedir.Dir() - home = filepath.Join(homedir, ".zed") + hmdir, _ := homedir.Dir() + home = filepath.Join(hmdir, ".zed") } return &storage.JSONConfigStore{ConfigPath: home}, &storage.KeychainSecretStore{ConfigPath: home} } -func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) { +func certOption(token storage.Token) (opt grpc.DialOption, err error) { verification := grpcutil.VerifyCA - if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() { + if token.HasNoVerifyCA() { verification = grpcutil.SkipVerifyCA } if certBytes, ok := token.Certificate(); ok { return grpcutil.WithCustomCertBytes(verification, certBytes) } + return grpcutil.WithSystemCerts(verification) } @@ -96,12 +180,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers), } - if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) { + if token.IsInsecure() { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken)) } else { opts = append(opts, grpcutil.WithBearerToken(token.APIToken)) - certOpt, err := certOption(cmd, token) + certOpt, err := certOption(token) if err != nil { return nil, fmt.Errorf("failed to configure TLS cert: %w", err) } diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..e1a0d92 --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,62 @@ +package client_test + +import ( + "os" + "testing" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/storage" + zedtesting "github.com/authzed/zed/internal/testing" + + "github.com/stretchr/testify/require" +) + +func TestGetTokenWithCLIOverride(t *testing.T) { + testCert, err := os.CreateTemp("", "") + require.NoError(t, err) + _, err = testCert.Write([]byte("hi")) + require.NoError(t, err) + cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, + zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true}, + zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true}, + zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true}, + zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true}, + zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true}, + ) + + bTrue := true + bFalse := false + + // cli args take precedence when defined + to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{}) + require.NoError(t, err) + require.True(t, to.AnyValue()) + require.Equal(t, "t1", to.APIToken) + require.Equal(t, "e1", to.Endpoint) + require.Equal(t, []byte("hi"), to.CACert) + require.Equal(t, &bTrue, to.Insecure) + require.Equal(t, &bTrue, to.NoVerifyCA) + + // storage token takes precedence when defined + cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t, + zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false}, + zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false}, + zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false}, + zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false}, + zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false}, + ) + to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{ + APIToken: "t2", + Endpoint: "e2", + CACert: []byte("bye"), + Insecure: &bFalse, + NoVerifyCA: &bFalse, + }) + require.NoError(t, err) + require.True(t, to.AnyValue()) + require.Equal(t, "t2", to.APIToken) + require.Equal(t, "e2", to.Endpoint) + require.Equal(t, []byte("bye"), to.CACert) + require.Equal(t, &bFalse, to.Insecure) + require.Equal(t, &bFalse, to.NoVerifyCA) +} diff --git a/internal/cmd/schema.go b/internal/cmd/schema.go index dc8439b..28b1186 100644 --- a/internal/cmd/schema.go +++ b/internal/cmd/schema.go @@ -9,7 +9,6 @@ import ( "strings" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" - "github.com/authzed/authzed-go/v1" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/schemadsl/input" @@ -23,7 +22,6 @@ import ( "github.com/authzed/zed/internal/client" "github.com/authzed/zed/internal/commands" "github.com/authzed/zed/internal/console" - "github.com/authzed/zed/internal/storage" ) func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) { @@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{ RunE: schemaCopyCmdFunc, } -// TODO(jschorr): support this in the client package -func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) { - token, err := storage.GetToken(contextName, secretStore) - if err != nil { - return nil, err - } - log.Trace().Interface("token", token).Send() - - dialOpts, err := client.DialOptsFromFlags(cmd, token) - if err != nil { - return nil, err - } - return authzed.NewClient(token.Endpoint, dialOpts...) -} - func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { _, secretStore := client.DefaultStorage() - srcClient, err := clientForContext(cmd, args[0], secretStore) + srcClient, err := client.NewClientForContext(cmd, args[0], secretStore) if err != nil { return err } - destClient, err := clientForContext(cmd, args[1], secretStore) + + destClient, err := client.NewClientForContext(cmd, args[1], secretStore) if err != nil { return err } diff --git a/internal/cmd/version.go b/internal/cmd/version.go index 8c67b16..95d30d8 100644 --- a/internal/cmd/version.go +++ b/internal/cmd/version.go @@ -9,14 +9,12 @@ import ( "github.com/gookit/color" "github.com/jzelinskie/cobrautil/v2" "github.com/mattn/go-isatty" - "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/authzed/zed/internal/client" "github.com/authzed/zed/internal/console" - "github.com/authzed/zed/internal/storage" ) func versionCmdFunc(cmd *cobra.Command, _ []string) error { @@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error { includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version") hasContext := false - configStore, secretStore := client.DefaultStorage() if includeRemoteVersion { - _, err := storage.DefaultToken( - cobrautil.MustGetString(cmd, "endpoint"), - cobrautil.MustGetString(cmd, "token"), - configStore, - secretStore, - ) + configStore, secretStore := client.DefaultStorage() + _, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) hasContext = err == nil } @@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error { console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps"))) if hasContext && includeRemoteVersion { - token, err := storage.DefaultToken( - cobrautil.MustGetString(cmd, "endpoint"), - cobrautil.MustGetString(cmd, "token"), - configStore, - secretStore, - ) - if err != nil { - return err - } - log.Trace().Interface("token", token).Send() - client, err := client.NewClient(cmd) if err != nil { return err diff --git a/internal/storage/config.go b/internal/storage/config.go index ef379d0..739079e 100644 --- a/internal/storage/config.go +++ b/internal/storage/config.go @@ -27,38 +27,47 @@ type ConfigStore interface { Put(Config) error } -var ErrMissingToken = errors.New("could not find token") +// TokenWithOverride returns a Token that retrieves its values from the reference Token, and has its values overridden +// any of the non-empty/non-nil values of the overrideToken. +func TokenWithOverride(overrideToken Token, referenceToken Token) (Token, error) { + insecure := referenceToken.Insecure + if overrideToken.Insecure != nil { + insecure = overrideToken.Insecure + } -// DefaultToken creates a Token from input, filling any missing values in -// with the current context's defaults. -func DefaultToken(overrideEndpoint, overrideAPIToken string, cs ConfigStore, ss SecretStore) (Token, error) { - if overrideEndpoint != "" && overrideAPIToken != "" { - return Token{ - Name: "env", - Endpoint: overrideEndpoint, - APIToken: overrideAPIToken, - }, nil + // done so that logging messages don't show nil for the resulting context + if insecure == nil { + bFalse := false + insecure = &bFalse } - token, err := CurrentToken(cs, ss) - if err != nil { - if errors.Is(err, ErrConfigNotFound) { - return Token{}, errors.New("no context found: see `zed context set --help` to setup a context or make sure to specifiy *all* context flags (--endpoint, --token and --insecure if necessary) to run without context") - } - return Token{}, err + noVerifyCA := referenceToken.NoVerifyCA + if overrideToken.NoVerifyCA != nil { + noVerifyCA = overrideToken.NoVerifyCA + } + + // done so that logging messages don't show nil for the resulting context + if noVerifyCA == nil { + bFalse := false + noVerifyCA = &bFalse + } + + caCert := referenceToken.CACert + if overrideToken.CACert != nil { + caCert = overrideToken.CACert } return Token{ - Name: token.Name, - Endpoint: stringz.DefaultEmpty(overrideEndpoint, token.Endpoint), - APIToken: stringz.DefaultEmpty(overrideAPIToken, token.APIToken), - Insecure: token.Insecure, - NoVerifyCA: token.NoVerifyCA, - CACert: token.CACert, + Name: referenceToken.Name, + Endpoint: stringz.DefaultEmpty(overrideToken.Endpoint, referenceToken.Endpoint), + APIToken: stringz.DefaultEmpty(overrideToken.APIToken, referenceToken.APIToken), + Insecure: insecure, + NoVerifyCA: noVerifyCA, + CACert: caCert, }, nil } -// CurrentToken is convenient way to obtain the CurrentToken field from the +// CurrentToken is a convenient way to obtain the CurrentToken field from the // current Config. func CurrentToken(cs ConfigStore, ss SecretStore) (Token, error) { cfg, err := cs.Get() diff --git a/internal/storage/config_test.go b/internal/storage/config_test.go new file mode 100644 index 0000000..aa65d6e --- /dev/null +++ b/internal/storage/config_test.go @@ -0,0 +1,48 @@ +package storage + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTokenWithOverride(t *testing.T) { + bTrue := true + referenceToken := Token{ + Name: "n1", + Endpoint: "e1", + APIToken: "a1", + Insecure: &bTrue, + NoVerifyCA: &bTrue, + CACert: []byte("c1"), + } + + bFalse := false + override := Token{ + Name: "n2", + Endpoint: "e2", + APIToken: "a2", + Insecure: &bFalse, + NoVerifyCA: &bFalse, + CACert: []byte("c2"), + } + + result, err := TokenWithOverride(override, referenceToken) + require.NoError(t, err) + require.Equal(t, "n1", result.Name) + require.Equal(t, "e2", result.Endpoint) + require.Equal(t, "a2", result.APIToken) + require.Equal(t, false, *result.Insecure) + require.Equal(t, false, *result.NoVerifyCA) + require.Equal(t, 0, bytes.Compare([]byte("c2"), result.CACert)) + + result, err = TokenWithOverride(Token{}, referenceToken) + require.NoError(t, err) + require.Equal(t, "n1", result.Name) + require.Equal(t, "e1", result.Endpoint) + require.Equal(t, "a1", result.APIToken) + require.Equal(t, true, *result.Insecure) + require.Equal(t, true, *result.NoVerifyCA) + require.Equal(t, 0, bytes.Compare([]byte("c1"), result.CACert)) +} diff --git a/internal/storage/secrets.go b/internal/storage/secrets.go index d29d9a9..90d1ebb 100644 --- a/internal/storage/secrets.go +++ b/internal/storage/secrets.go @@ -26,6 +26,14 @@ type Token struct { CACert []byte } +func (t Token) AnyValue() bool { + if t.Endpoint != "" || t.APIToken != "" || t.Insecure != nil || t.NoVerifyCA != nil || len(t.CACert) > 0 { + return true + } + + return false +} + func (t Token) Certificate() (cert []byte, ok bool) { if len(t.CACert) > 0 { return t.CACert, true diff --git a/internal/storage/secrets_test.go b/internal/storage/secrets_test.go new file mode 100644 index 0000000..da8dd15 --- /dev/null +++ b/internal/storage/secrets_test.go @@ -0,0 +1,19 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTokenAnyValue(t *testing.T) { + b := false + + require.False(t, Token{}.AnyValue()) + require.False(t, Token{}.AnyValue()) + require.True(t, Token{Endpoint: "foo"}.AnyValue()) + require.True(t, Token{APIToken: "foo"}.AnyValue()) + require.True(t, Token{Insecure: &b}.AnyValue()) + require.True(t, Token{NoVerifyCA: &b}.AnyValue()) + require.True(t, Token{CACert: []byte("a")}.AnyValue()) +} diff --git a/internal/testing/test_helpers.go b/internal/testing/test_helpers.go index 451b2ff..8535ba0 100644 --- a/internal/testing/test_helpers.go +++ b/internal/testing/test_helpers.go @@ -64,31 +64,37 @@ func NewTestServer(ctx context.Context, t *testing.T) server.RunnableServer { type StringFlag struct { FlagName string FlagValue string + Changed bool } type BoolFlag struct { FlagName string FlagValue bool + Changed bool } type IntFlag struct { FlagName string FlagValue int + Changed bool } type UintFlag struct { FlagName string FlagValue uint + Changed bool } type UintFlag32 struct { FlagName string FlagValue uint32 + Changed bool } type DurationFlag struct { FlagName string FlagValue time.Duration + Changed bool } func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *cobra.Command { @@ -99,16 +105,22 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co switch f := flagAndValue.(type) { case StringFlag: c.Flags().String(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case BoolFlag: c.Flags().Bool(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case IntFlag: c.Flags().Int(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case UintFlag: c.Flags().Uint(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case UintFlag32: c.Flags().Uint32(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case DurationFlag: c.Flags().Duration(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed default: t.Fatalf("unknown flag type: %T", f) }