diff --git a/cmd/spicedb/main.go b/cmd/spicedb/main.go index 1045f585d7..cabbfb3636 100644 --- a/cmd/spicedb/main.go +++ b/cmd/spicedb/main.go @@ -16,6 +16,10 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/cmd" + "github.com/authzed/spicedb/pkg/cmd/cockroachdb" + "github.com/authzed/spicedb/pkg/cmd/memory" + "github.com/authzed/spicedb/pkg/cmd/mysql" + "github.com/authzed/spicedb/pkg/cmd/postgres" cmdutil "github.com/authzed/spicedb/pkg/cmd/server" "github.com/authzed/spicedb/pkg/cmd/testserver" _ "github.com/authzed/spicedb/pkg/runtime" @@ -56,6 +60,7 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("failed to register datastore command") } + datastoreCmd.Hidden = true cmd.RegisterDatastoreRootFlags(datastoreCmd) rootCmd.AddCommand(datastoreCmd) @@ -74,14 +79,44 @@ func main() { cmd.RegisterMigrateFlags(migrateCmd) rootCmd.AddCommand(migrateCmd) + // Add datastore commands + rootCmd.AddGroup(&cobra.Group{ + ID: "datastores", + Title: "Datastores:", + }) + pgCmd, err := postgres.NewPostgresCommand(rootCmd.Use) + if err != nil { + log.Fatal().Err(err).Msg("failed to register serve flags") + } + rootCmd.AddCommand(pgCmd) + crdbCmd, err := cockroachdb.NewCommand(rootCmd.Use) + if err != nil { + log.Fatal().Err(err).Msg("failed to register serve flags") + } + rootCmd.AddCommand(crdbCmd) + memCmd, err := memory.NewCommand(rootCmd.Use) + if err != nil { + log.Fatal().Err(err).Msg("failed to register serve flags") + } + rootCmd.AddCommand(memCmd) + myCmd := mysql.NewCommand(rootCmd.Use) + rootCmd.AddCommand(myCmd) + // Add server commands serverConfig := cmdutil.NewConfigWithOptionsAndDefaults() serveCmd := cmd.NewServeCommand(rootCmd.Use, serverConfig) if err := cmd.RegisterServeFlags(serveCmd, serverConfig); err != nil { log.Fatal().Err(err).Msg("failed to register server flags") } + serveCmd.Hidden = true rootCmd.AddCommand(serveCmd) + // Add developer tools + rootCmd.AddGroup(&cobra.Group{ + ID: "devtools", + Title: "Developer Tools:", + }) + devtoolsCmd := cmd.NewDevtoolsCommand(rootCmd.Use) cmd.RegisterDevtoolsFlags(devtoolsCmd) rootCmd.AddCommand(devtoolsCmd) @@ -117,6 +152,11 @@ func main() { }) if err := rootCmd.Execute(); err != nil { + // Ensure that logging has been set-up before printing an error. + if preRunErr := rootCmd.PersistentPreRunE(rootCmd, nil); preRunErr != nil { + panic(preRunErr) + } + if !errors.Is(err, errParsing) { log.Err(err).Msg("terminated with errors") } diff --git a/go.mod b/go.mod index 799d6d9d06..ab4ae8f949 100644 --- a/go.mod +++ b/go.mod @@ -91,7 +91,6 @@ require ( go.opentelemetry.io/otel v1.28.0 go.opentelemetry.io/otel/sdk v1.28.0 go.opentelemetry.io/otel/trace v1.28.0 - go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mod v0.19.0 @@ -381,6 +380,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0 // indirect go.opentelemetry.io/otel/metric v1.28.0 // indirect go.opentelemetry.io/proto/otlp v1.0.0 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect diff --git a/internal/graph/checkingresourcestream.go b/internal/graph/checkingresourcestream.go index 62769220a5..81d13d54a5 100644 --- a/internal/graph/checkingresourcestream.go +++ b/internal/graph/checkingresourcestream.go @@ -5,8 +5,8 @@ import ( "context" "slices" "sync" + "sync/atomic" - "go.uber.org/atomic" "golang.org/x/exp/maps" "github.com/authzed/spicedb/internal/dispatch" diff --git a/pkg/closer/closer.go b/pkg/closer/closer.go new file mode 100644 index 0000000000..78b53d422b --- /dev/null +++ b/pkg/closer/closer.go @@ -0,0 +1,46 @@ +package closer + +import ( + "io" + + "github.com/hashicorp/go-multierror" +) + +type Stack struct { + closers []func() error +} + +func (c *Stack) AddWithError(closer func() error) { + c.closers = append(c.closers, closer) +} + +func (c *Stack) AddCloser(closer io.Closer) { + if closer != nil { + c.closers = append(c.closers, closer.Close) + } +} + +func (c *Stack) AddWithoutError(closer func()) { + c.closers = append(c.closers, func() error { + closer() + return nil + }) +} + +func (c *Stack) Close() error { + var err error + // closer in reverse order how it's expected in deferred funcs + for i := len(c.closers) - 1; i >= 0; i-- { + if closerErr := c.closers[i](); closerErr != nil { + err = multierror.Append(err, closerErr) + } + } + return err +} + +func (c *Stack) CloseIfError(err error) error { + if err != nil { + return c.Close() + } + return nil +} diff --git a/pkg/cmd/cockroachdb/cockroachdb.go b/pkg/cmd/cockroachdb/cockroachdb.go new file mode 100644 index 0000000000..25337929e6 --- /dev/null +++ b/pkg/cmd/cockroachdb/cockroachdb.go @@ -0,0 +1,137 @@ +package cockroachdb + +import ( + "fmt" + "time" + + "github.com/go-logr/zerologr" + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/cobrautil/v2/cobraotel" + "github.com/spf13/cobra" + + "github.com/authzed/spicedb/internal/datastore/crdb/migrations" + "github.com/authzed/spicedb/internal/logging" + pkgcmd "github.com/authzed/spicedb/pkg/cmd" + "github.com/authzed/spicedb/pkg/cmd/server" + "github.com/authzed/spicedb/pkg/migrate" + "github.com/authzed/spicedb/pkg/releases" + "github.com/authzed/spicedb/pkg/runtime" +) + +func NewCommand(programName string) (*cobra.Command, error) { + crdbCmd := &cobra.Command{ + Use: "cockroachdb", + Aliases: []string{"cockroach", "crdb"}, + Short: "Perform operations on data stored in CockroachDB", + GroupID: "datastores", + Hidden: false, + } + migrationsCmd := NewMigrationCommand(programName) + crdbCmd.AddCommand(migrationsCmd) + + cfg := &server.Config{} + cfg.DatastoreConfig.Engine = "cockroachdb" + cfg.NamespaceCacheConfig = pkgcmd.NamespaceCacheConfig + cfg.ClusterDispatchCacheConfig = server.CacheConfig{} + cfg.DispatchCacheConfig = server.CacheConfig{} + + serveCmd := &cobra.Command{ + Use: "serve-grpc", + Short: "Serve the SpiceDB gRPC API services", + Example: pkgcmd.ServeExample(programName), + PreRunE: cobrautil.CommandStack( + cobraotel.New("spicedb", cobraotel.WithLogger(zerologr.New(&logging.Logger))).RunE(), + releases.CheckAndLogRunE(), + runtime.RunE(), + ), + RunE: pkgcmd.ServeGRPCRunE(cfg), + } + + nfs := cobrautil.NewNamedFlagSets(serveCmd) + if err := pkgcmd.RegisterCRDBDatastoreFlags(serveCmd, nfs.FlagSet("CockroachDB Datastore"), cfg); err != nil { + return nil, err + } + + postRegisterFn, err := pkgcmd.RegisterCommonServeFlags(programName, serveCmd, nfs, cfg, true) + if err != nil { + return nil, err + } + + // Flags must be registered to the command after flags are set. + nfs.AddFlagSets(serveCmd) + if err := postRegisterFn(); err != nil { + return nil, err + } + + crdbCmd.AddCommand(serveCmd) + + return crdbCmd, nil +} + +func NewMigrationCommand(programName string) *cobra.Command { + migrationsCmd := &cobra.Command{ + Use: "migrations", + Short: "Perform migrations and schema changes", + } + + headCmd := &cobra.Command{ + Use: "head", + Short: "Print the latest migration", + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + head, err := migrations.CRDBMigrations.HeadRevision() + if err != nil { + return fmt.Errorf("unable to compute head revision: %w", err) + } + _, err = fmt.Println(head) + return err + }, + } + migrationsCmd.AddCommand(headCmd) + + execCmd := &cobra.Command{ + Use: "exec ", + Short: "Execute all migrations up to and including the provided migration", + Args: cobra.ExactArgs(1), + RunE: ExecMigrationRunE, + } + migrationsCmd.AddCommand(execCmd) + RegisterMigrationExecFlags(execCmd) + + return migrationsCmd +} + +func RegisterMigrationExecFlags(cmd *cobra.Command) { + cmd.Flags().String("crdb-uri", "postgres://roach:password@localhost:5432/spicedb", "connection string in URI format") + cmd.Flags().Uint64("backfill-batch-size", 1000, "batch size used when backfilling data") + cmd.Flags().Duration("timeout", 1*time.Hour, "maximum execution duration for an individual migration") +} + +func ExecMigrationRunE(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + revision := args[0] + if revision == "head" { + head, err := migrations.CRDBMigrations.HeadRevision() + if err != nil { + return fmt.Errorf("unable to compute head revision: %w", err) + } + revision = head + } + + logging.Ctx(ctx).Info().Str("target", revision).Msg("executing migrations") + + migrationDriver, err := migrations.NewCRDBDriver(cobrautil.MustGetStringExpanded(cmd, "crdb-uri")) + if err != nil { + return fmt.Errorf("unable to create cockroachdb migration driver: %w", err) + } + + return migrate.RunMigration( + cmd.Context(), + migrationDriver, + migrations.CRDBMigrations, + revision, + cobrautil.MustGetDuration(cmd, "timeout"), + cobrautil.MustGetUint64(cmd, "backfill-batch-size"), + ) +} diff --git a/pkg/cmd/devtools.go b/pkg/cmd/devtools.go index 55b2a1b8ee..c3debe0b15 100644 --- a/pkg/cmd/devtools.go +++ b/pkg/cmd/devtools.go @@ -49,6 +49,7 @@ func NewDevtoolsCommand(programName string) *cobra.Command { Use: "serve-devtools", Short: "runs the developer tools service", Long: "Serves the authzed.api.v0.DeveloperService which is used for development tooling such as the Authzed Playground", + GroupID: "devtools", PreRunE: server.DefaultPreRunE(programName), RunE: termination.PublishError(runfunc), Args: cobra.ExactArgs(0), diff --git a/pkg/cmd/lsp.go b/pkg/cmd/lsp.go index a3ed139692..1da5cec551 100644 --- a/pkg/cmd/lsp.go +++ b/pkg/cmd/lsp.go @@ -4,13 +4,8 @@ import ( "context" "time" - "github.com/go-logr/zerologr" - "github.com/jzelinskie/cobrautil/v2" - "github.com/jzelinskie/cobrautil/v2/cobrazerolog" - "github.com/rs/zerolog" "github.com/spf13/cobra" - "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/internal/lsp" "github.com/authzed/spicedb/pkg/cmd/termination" "github.com/authzed/spicedb/pkg/releases" @@ -37,17 +32,10 @@ func RegisterLSPFlags(cmd *cobra.Command, config *LSPConfig) error { func NewLSPCommand(programName string, config *LSPConfig) *cobra.Command { return &cobra.Command{ - Use: "lsp", - Short: "serve language server protocol", - PreRunE: cobrautil.CommandStack( - cobrautil.SyncViperDotEnvPreRunE(programName, "spicedb.env", zerologr.New(&logging.Logger)), - cobrazerolog.New( - cobrazerolog.WithTarget(func(logger zerolog.Logger) { - logging.SetGlobalLogger(logger) - }), - ).RunE(), - releases.CheckAndLogRunE(), - ), + Use: "lsp", + Short: "serve language server protocol", + GroupID: "devtools", + PreRunE: releases.CheckAndLogRunE(), RunE: termination.PublishError(func(cmd *cobra.Command, args []string) error { srv, err := config.Complete(cmd.Context()) if err != nil { diff --git a/pkg/cmd/memory/memory.go b/pkg/cmd/memory/memory.go new file mode 100644 index 0000000000..2d58570457 --- /dev/null +++ b/pkg/cmd/memory/memory.go @@ -0,0 +1,60 @@ +package memory + +import ( + "github.com/go-logr/zerologr" + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/cobrautil/v2/cobraotel" + "github.com/spf13/cobra" + + "github.com/authzed/spicedb/internal/logging" + pkgcmd "github.com/authzed/spicedb/pkg/cmd" + "github.com/authzed/spicedb/pkg/cmd/server" + "github.com/authzed/spicedb/pkg/releases" + "github.com/authzed/spicedb/pkg/runtime" +) + +const ( + exampleWithoutTLS = "memory serve-grpc --grpc-preshared-key secretKeyHere" + exampleWithTLS = "memory serve-grpc --grpc-preshared-key secretKeyHere --grpc-tls-cert path/to/cert --grpc-tls-key path/to/key" +) + +func NewCommand(programName string) (*cobra.Command, error) { + memCmd := &cobra.Command{ + Use: "memory", + Aliases: []string{"mem"}, + Short: "Perform operations on data stored in non-persistent memory", + GroupID: "datastores", + } + + cfg := &server.Config{} + cfg.DatastoreConfig.Engine = "memory" + cfg.NamespaceCacheConfig = pkgcmd.NamespaceCacheConfig + cfg.ClusterDispatchCacheConfig = server.CacheConfig{} + cfg.DispatchCacheConfig = server.CacheConfig{} + + serveCmd := &cobra.Command{ + Use: "serve-grpc", + Short: "Serve the SpiceDB gRPC API services", + Example: pkgcmd.ServeExample(programName, exampleWithoutTLS, exampleWithTLS), + PreRunE: cobrautil.CommandStack( + cobraotel.New("spicedb", cobraotel.WithLogger(zerologr.New(&logging.Logger))).RunE(), + releases.CheckAndLogRunE(), + runtime.RunE(), + ), + RunE: pkgcmd.ServeGRPCRunE(cfg), + } + nfs := cobrautil.NewNamedFlagSets(serveCmd) + postRegisterFn, err := pkgcmd.RegisterCommonServeFlags(programName, serveCmd, nfs, cfg, false) + if err != nil { + return nil, err + } + + // Flags must be registered to the command after flags are set. + nfs.AddFlagSets(serveCmd) + if err := postRegisterFn(); err != nil { + return nil, err + } + memCmd.AddCommand(serveCmd) + + return memCmd, nil +} diff --git a/pkg/cmd/migrate.go b/pkg/cmd/migrate.go index c0e00e68af..88f7a32cc2 100644 --- a/pkg/cmd/migrate.go +++ b/pkg/cmd/migrate.go @@ -118,6 +118,7 @@ func migrateRun(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot migrate datastore engine type: %s", datastoreEngine) } +// TODO(jzelinskie): deprecated, replace with migrate.RunMigration() func runMigration[D migrate.Driver[C, T], C any, T any]( ctx context.Context, driver D, diff --git a/pkg/cmd/mysql/mysql.go b/pkg/cmd/mysql/mysql.go new file mode 100644 index 0000000000..205d49830b --- /dev/null +++ b/pkg/cmd/mysql/mysql.go @@ -0,0 +1,109 @@ +package mysql + +import ( + "fmt" + "time" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/spf13/cobra" + + "github.com/authzed/spicedb/internal/datastore/mysql/migrations" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/migrate" +) + +func NewCommand(programName string) *cobra.Command { + crdbCmd := &cobra.Command{ + Use: "mysql", + Aliases: []string{"mariadb", "vitess"}, + Short: "Perform operations on data stored in MySQL variants", + GroupID: "datastores", + Hidden: false, + } + migrationsCmd := NewMigrationCommand(programName) + crdbCmd.AddCommand(migrationsCmd) + + return crdbCmd +} + +func NewMigrationCommand(programName string) *cobra.Command { + migrationsCmd := &cobra.Command{ + Use: "migrations", + Short: "Perform migrations and schema changes", + } + + headCmd := &cobra.Command{ + Use: "head", + Short: "Print the latest migration", + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + head, err := migrations.Manager.HeadRevision() + if err != nil { + return fmt.Errorf("unable to compute head revision: %w", err) + } + _, err = fmt.Println(head) + return err + }, + } + migrationsCmd.AddCommand(headCmd) + + execCmd := &cobra.Command{ + Use: "exec ", + Short: "Execute all migrations up to and including the provided migration", + Args: cobra.ExactArgs(1), + RunE: ExecMigrationRunE, + } + migrationsCmd.AddCommand(execCmd) + RegisterMigrationExecFlags(execCmd) + + return migrationsCmd +} + +func RegisterMigrationExecFlags(cmd *cobra.Command) { + cmd.Flags().String("mysql-uri", "mysql://mysql:password@localhost:5432/spicedb", "connection string in URI format") + cmd.Flags().String("mysql-table-prefix", "", "prefix to include in all table names") + cmd.Flags().Uint64("backfill-batch-size", 1000, "batch size used when backfilling data") + cmd.Flags().Duration("timeout", 1*time.Hour, "maximum execution duration for an individual migration") +} + +func ExecMigrationRunE(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + revision := args[0] + if revision == "head" { + head, err := migrations.Manager.HeadRevision() + if err != nil { + return fmt.Errorf("unable to compute head revision: %w", err) + } + revision = head + } + + var credsProvider datastore.CredentialsProvider + if providerName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name"); providerName != "" { + var err error + credsProvider, err = datastore.NewCredentialsProvider(ctx, providerName) + if err != nil { + return err + } + } + + migrationDriver, err := migrations.NewMySQLDriverFromDSN( + cobrautil.MustGetStringExpanded(cmd, "mysql-uri"), + cobrautil.MustGetStringExpanded(cmd, "mysql-table-prefix"), + credsProvider, + ) + if err != nil { + return fmt.Errorf("unable to create mysql migration driver: %w", err) + } + + log.Ctx(ctx).Info().Str("target", revision).Msg("executing migrations") + return migrate.RunMigration( + ctx, + migrationDriver, + migrations.Manager, + revision, + cobrautil.MustGetDuration(cmd, "timeout"), + cobrautil.MustGetUint64(cmd, "backfill-batch-size"), + ) +} diff --git a/pkg/cmd/postgres/postgres.go b/pkg/cmd/postgres/postgres.go new file mode 100644 index 0000000000..f441e92073 --- /dev/null +++ b/pkg/cmd/postgres/postgres.go @@ -0,0 +1,179 @@ +package postgres + +import ( + "fmt" + "time" + + "github.com/go-logr/zerologr" + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/cobrautil/v2/cobraotel" + "github.com/spf13/cobra" + + "github.com/authzed/spicedb/internal/datastore/postgres" + "github.com/authzed/spicedb/internal/datastore/postgres/migrations" + "github.com/authzed/spicedb/internal/logging" + pkgcmd "github.com/authzed/spicedb/pkg/cmd" + "github.com/authzed/spicedb/pkg/cmd/server" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/migrate" + "github.com/authzed/spicedb/pkg/releases" + "github.com/authzed/spicedb/pkg/runtime" +) + +const ( + exampleWithoutTLS = "pg serve-grpc --grpc-preshared-key secretKeyHere --pg-uri postgres://postgres:password@localhost:5432" +) + +func NewPostgresCommand(programName string) (*cobra.Command, error) { + pgCmd := &cobra.Command{ + Use: "postgres", + Aliases: []string{"pg", "postgresql"}, + Short: "Perform operations on data stored in PostgreSQL", + GroupID: "datastores", + Hidden: false, + } + migrationsCmd := NewMigrationCommand(programName) + pgCmd.AddCommand(migrationsCmd) + + cfg := &server.Config{} + cfg.DatastoreConfig.Engine = "postgres" + cfg.NamespaceCacheConfig = pkgcmd.NamespaceCacheConfig + cfg.ClusterDispatchCacheConfig = server.CacheConfig{} + cfg.DispatchCacheConfig = server.CacheConfig{} + + serveCmd := &cobra.Command{ + Use: "serve-grpc", + Short: "Serve the SpiceDB gRPC API services", + Example: pkgcmd.ServeExample(programName, exampleWithoutTLS), + PreRunE: cobrautil.CommandStack( + cobraotel.New("spicedb", cobraotel.WithLogger(zerologr.New(&logging.Logger))).RunE(), + releases.CheckAndLogRunE(), + runtime.RunE(), + ), + RunE: pkgcmd.ServeGRPCRunE(cfg), + } + + nfs := cobrautil.NewNamedFlagSets(serveCmd) + if err := pkgcmd.RegisterPostgresDatastoreFlags(serveCmd, nfs.FlagSet("Postgres Datastore"), cfg); err != nil { + return nil, err + } + + postRegisterFn, err := pkgcmd.RegisterCommonServeFlags(programName, serveCmd, nfs, cfg, true) + if err != nil { + return nil, err + } + + // Flags must be registered to the command after flags are set. + nfs.AddFlagSets(serveCmd) + if err := postRegisterFn(); err != nil { + return nil, err + } + + pgCmd.AddCommand(serveCmd) + + return pgCmd, nil +} + +func NewMigrationCommand(programName string) *cobra.Command { + migrationsCmd := &cobra.Command{ + Use: "migrations", + Short: "Perform migrations and schema changes", + } + + headCmd := &cobra.Command{ + Use: "head", + Short: "Print the latest migration", + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + head, err := migrations.DatabaseMigrations.HeadRevision() + if err != nil { + return fmt.Errorf("unable to compute head revision: %w", err) + } + _, err = fmt.Println(head) + return err + }, + } + migrationsCmd.AddCommand(headCmd) + + execCmd := &cobra.Command{ + Use: "exec ", + Short: "Execute all migrations up to and including the provided migration", + Args: cobra.ExactArgs(1), + RunE: ExecMigrationRunE, + } + migrationsCmd.AddCommand(execCmd) + RegisterMigrationExecFlags(execCmd) + + repairCmd := &cobra.Command{ + Use: "repair-txids", + Short: "Fast-fowards the Postgres txid counter (required for migrating to new instances)", + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + ds, err := postgres.NewPostgresDatastore(ctx, cobrautil.MustGetStringExpanded(cmd, "pg-uri")) + if err != nil { + return fmt.Errorf("failed to create datastore: %w", err) + } + repairable := datastore.UnwrapAs[datastore.RepairableDatastore](ds) + if repairable == nil { + return fmt.Errorf("datastore of type %T does not support the repair operation", ds) + } + + start := time.Now() + if err := repairable.Repair(ctx, "transaction-ids", true); err != nil { + return err + } + repairDuration := time.Since(start) + + logging.Ctx(ctx).Info().Dur("duration", repairDuration).Msg("datastore repair completed") + return nil + }, + } + repairCmd.Flags().String("pg-uri", "postgres://postgres:password@localhost:5432/spicedb", "connection string in URI format") + migrationsCmd.AddCommand(repairCmd) + + return migrationsCmd +} + +func RegisterMigrationExecFlags(cmd *cobra.Command) { + cmd.Flags().String("pg-uri", "postgres://postgres:password@localhost:5432/spicedb", "connection string in URI format") + cmd.Flags().Uint64("backfill-batch-size", 1000, "batch size used when backfilling data") + cmd.Flags().Duration("timeout", 1*time.Hour, "maximum execution duration for an individual migration") +} + +func ExecMigrationRunE(cmd *cobra.Command, args []string) error { + revision := args[0] + if revision == "head" { + head, err := migrations.DatabaseMigrations.HeadRevision() + if err != nil { + return fmt.Errorf("unable to compute head revision: %w", err) + } + revision = head + } + + logging.Ctx(cmd.Context()).Info().Str("target", revision).Msg("executing migrations") + + var credentialsProvider datastore.CredentialsProvider + credentialsProviderName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name") + if credentialsProviderName != "" { + var err error + credentialsProvider, err = datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName) + if err != nil { + return err + } + } + + migrationDriver, err := migrations.NewAlembicPostgresDriver(cmd.Context(), cobrautil.MustGetStringExpanded(cmd, "pg-uri"), credentialsProvider) + if err != nil { + return fmt.Errorf("unable to create postgres migration driver: %w", err) + } + + return migrate.RunMigration( + cmd.Context(), + migrationDriver, + migrations.DatabaseMigrations, + revision, + cobrautil.MustGetDuration(cmd, "timeout"), + cobrautil.MustGetUint64(cmd, "backfill-batch-size"), + ) +} diff --git a/pkg/cmd/root.go b/pkg/cmd/root.go index 19cefb8c24..0fec34ab28 100644 --- a/pkg/cmd/root.go +++ b/pkg/cmd/root.go @@ -3,16 +3,14 @@ package cmd import ( "fmt" + "github.com/go-logr/zerologr" "github.com/jzelinskie/cobrautil/v2" - "github.com/jzelinskie/cobrautil/v2/cobraotel" "github.com/jzelinskie/cobrautil/v2/cobrazerolog" + "github.com/rs/zerolog" "github.com/spf13/cobra" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/cmd/server" - "github.com/authzed/spicedb/pkg/cmd/termination" - "github.com/authzed/spicedb/pkg/releases" - "github.com/authzed/spicedb/pkg/runtime" ) func RegisterRootFlags(cmd *cobra.Command) error { @@ -21,17 +19,6 @@ func RegisterRootFlags(cmd *cobra.Command) error { if err := zl.RegisterFlagCompletion(cmd); err != nil { return fmt.Errorf("failed to register zerolog flag completion: %w", err) } - - ot := cobraotel.New(cmd.Use) - ot.RegisterFlags(cmd.PersistentFlags()) - if err := ot.RegisterFlagCompletion(cmd); err != nil { - return fmt.Errorf("failed to register otel flag completion: %w", err) - } - - releases.RegisterFlags(cmd.PersistentFlags()) - termination.RegisterFlags(cmd.PersistentFlags()) - runtime.RegisterFlags(cmd.PersistentFlags()) - return nil } @@ -51,5 +38,14 @@ func NewRootCommand(programName string) *cobra.Command { Example: server.ServeExample(programName), SilenceErrors: true, SilenceUsage: true, + PersistentPreRunE: cobrautil.CommandStack( + cobrautil.SyncViperDotEnvPreRunE(programName, "spicedb.env", zerologr.New(&log.Logger)), + cobrazerolog.New( + cobrazerolog.WithTarget(func(logger zerolog.Logger) { + log.SetGlobalLogger(logger) + }), + cobrazerolog.WithPreRunLevel(zerolog.DebugLevel), + ).RunE(), + ), } } diff --git a/pkg/cmd/serve_new.go b/pkg/cmd/serve_new.go new file mode 100644 index 0000000000..d78127dc68 --- /dev/null +++ b/pkg/cmd/serve_new.go @@ -0,0 +1,260 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/cobrautil/v2/cobraotel" + "github.com/jzelinskie/stringz" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + "github.com/authzed/spicedb/internal/dispatch/graph" + "github.com/authzed/spicedb/internal/telemetry" + "github.com/authzed/spicedb/pkg/cmd/server" + "github.com/authzed/spicedb/pkg/cmd/termination" + "github.com/authzed/spicedb/pkg/releases" + "github.com/authzed/spicedb/pkg/runtime" +) + +var NamespaceCacheConfig = server.CacheConfig{ + Name: "namespace", + Enabled: true, + Metrics: true, + NumCounters: 1_000, + MaxCost: "32MiB", +} + +// ServeExample creates the example usage string with the provided program name and examples +func ServeExample(programName string, examples ...string) string { + formatted := make([]string, 0, len(examples)) + for _, example := range examples { + formatted = append(formatted, fmt.Sprintf(" %s %s", programName, example)) + } + + return stringz.Join("\n", formatted...) +} + +func ServeGRPCRunE(cfg *server.Config) cobrautil.CobraRunFunc { + return termination.PublishError(func(cmd *cobra.Command, args []string) error { + // Workaround for the flags that have been reworked. + cfg.DisableVersionResponse = !cobrautil.MustGetBool(cmd, "api-version-response") + cfg.DisableV1SchemaAPI = !cobrautil.MustGetBool(cmd, "api-schema-v1-enabled") + cfg.GRPCServer.Network, cfg.GRPCServer.Address = ParseCombinedGRPCURI(cobrautil.MustGetStringExpanded(cmd, "grpc-addr")) + cfg.DatastoreConfig.EnableDatastoreMetrics = true + cfg.DatastoreConfig.GCInterval = 180000 + cfg.DatastoreConfig.GCMaxOperationTime = 60000 + cfg.DatastoreConfig.GCWindow = 86400000 + cfg.DatastoreConfig.LegacyFuzzing = -1 + cfg.NamespaceCacheConfig = NamespaceCacheConfig + + server, err := cfg.Complete(cmd.Context()) + if err != nil { + return err + } + signalctx := SignalContextWithGracePeriod( + context.Background(), + cfg.ShutdownGracePeriod, + ) + return server.Run(signalctx) + }) +} + +func RegisterCommonServeFlags(programName string, cmd *cobra.Command, nfs *cobrautil.NamedFlagSets, cfg *server.Config, includeRemoteDispatch bool) (func() error, error) { + if err := RegisterGRPCFlags(nfs.FlagSet("gRPC"), cfg); err != nil { + return nil, err + } + + if err := RegisterAPIFlags(nfs.FlagSet("API"), cfg); err != nil { + return nil, err + } + + dispatchFlags := nfs.FlagSet("Dispatch") + if err := RegisterDispatchFlags(dispatchFlags, cfg); err != nil { + return nil, err + } + + if includeRemoteDispatch { + if err := RegisterRemoteDispatchFlags(dispatchFlags, cfg); err != nil { + return nil, err + } + } + + if err := RegisterDatastoreFlags(nfs.FlagSet("Datastore"), cfg, "memory"); err != nil { + return nil, err + } + + RegisterSeedFlags(nfs.FlagSet("Data Seeding"), cfg) + + obsFlags := nfs.FlagSet("Observability") + if err := RegisterMetricsFlags(obsFlags, cfg); err != nil { + return nil, err + } + otel := cobraotel.New(programName) + otel.RegisterFlags(obsFlags) + runtime.RegisterFlags(obsFlags) // TODO(jzelinskie): hide these? + + miscFlags := nfs.FlagSet("Miscellaneous") + termination.RegisterFlags(miscFlags) + releases.RegisterFlags(miscFlags) + + return func() error { + if err := cmd.MarkFlagRequired("grpc-preshared-key"); err != nil { + return err + } + if err := otel.RegisterFlagCompletion(cmd); err != nil { + return fmt.Errorf("failed to register otel flag completion: %w", err) + } + return nil + }, nil +} + +func RegisterCRDBDatastoreFlags(cmd *cobra.Command, flags *pflag.FlagSet, cfg *server.Config) error { + flags.DurationVar(&cfg.DatastoreConfig.ConnectRate, "crdb-connect-rate", 100*time.Millisecond, "max rate for new establishing new connections") + flags.BoolVar(&cfg.DatastoreConfig.EnableConnectionBalancing, "crdb-conn-balance-enabled", true, "balance connections across discoverable database instances") + flags.DurationVar(&cfg.DatastoreConfig.FollowerReadDelay, "crdb-follower-read-delay", 4_800*time.Millisecond, "time subtracted from revision timestamps to ensure they are beyond any replication delay") + + return cobrautil.MarkFlagsHidden(flags, "crdb-conn-balance-enabled") +} + +func RegisterPostgresDatastoreFlags(cmd *cobra.Command, flags *pflag.FlagSet, cfg *server.Config) error { + flags.StringVar(&cfg.DatastoreConfig.URI, "pg-uri", "postgres://postgres:password@localhost:5432", "connection string used to connect to PostgreSQL") + flags.StringVar(&cfg.DatastoreConfig.CredentialsProviderName, "pg-credential-provider", "", "retrieve credentials from the environment") // TODO(jzelinskie): we should just detect and use them without any flag + flags.DurationVar(&cfg.DatastoreConfig.GCInterval, "pg-gc-interval", 3*time.Minute, "frequency with which garbage collection runs") + flags.DurationVar(&cfg.DatastoreConfig.GCMaxOperationTime, "pg-gc-timeout", time.Minute, "max duration for a garbage collection pass") + flags.DurationVar(&cfg.DatastoreConfig.GCWindow, "pg-gc-window", 24*time.Hour, "max age for a revision before it can be garbage collected") + return nil +} + +func RegisterGRPCFlags(flags *pflag.FlagSet, cfg *server.Config) error { + flags.StringVar(&cfg.GRPCServer.Address, "grpc-addr", ":50051", "address to listen on") + flags.DurationVar(&cfg.GRPCServer.MaxConnAge, "grpc-conn-age-limit", 30*time.Second, "max duration a connection should live") + flags.BoolVar(&cfg.GRPCServer.Enabled, "grpc-enabled", true, "enable the gRPC server") + // flags.StringSliceVar(&cfg.GRPCServer.GatewayAllowedOrigins, "grpc-gateway-allowed-origins", "CORS origins for the gRPC REST gateway") + // flags.BoolVar(&cfg. "grpc-gateway-enabled", false, "enable the gRPC REST gateway") + flags.Uint32Var(&cfg.GRPCServer.MaxWorkers, "grpc-workers", 0, "number of workers for this server (0 for 1/request)") + flags.StringVar(&cfg.GRPCServer.TLSCertPath, "grpc-tls-cert", "", "local path to the TLS certificate") + flags.StringVar(&cfg.GRPCServer.TLSKeyPath, "grpc-tls-key", "", "local path to the TLS key") + flags.StringSliceVar(&cfg.PresharedSecureKey, "grpc-preshared-key", []string{}, "preshared key(s) for authenticating requests") + + return cobrautil.MarkFlagsHidden(flags, "grpc-conn-age-limit", "grpc-workers") +} + +func RegisterRemoteDispatchFlags(flags *pflag.FlagSet, cfg *server.Config) error { + flags.BoolVar(&cfg.DispatchServer.Enabled, "dispatch-remote-enabled", true, "serve and request dispatches to/from other instances of SpiceDB") + flags.Uint32Var(&cfg.DispatchServer.MaxWorkers, "dispatch-remote-workers", 0, "number of workers for this server (0 for 1/request)") + flags.StringVar(&cfg.DispatchServer.Address, "dispatch-remote-addr", ":50053", "TODO") + flags.StringVar(&cfg.DispatchServer.TLSCertPath, "dispatch-remote-tls-cert", "", "TODO") + flags.StringVar(&cfg.DispatchServer.TLSKeyPath, "dispatch-remote-tls-key", "", "TODO") + flags.DurationVar(&cfg.DispatchServer.MaxConnAge, "dispatch-remote-conn-age-limit", 30*time.Second, "max duration a connection should live") + flags.DurationVar(&cfg.DispatchUpstreamTimeout, "dispatch-remote-timeout", time.Second, "max duration for a single dispatch") + flags.StringVar(&cfg.DispatchUpstreamCAPath, "dispatch-remote-ca", "", "local path to the certificate authority used to connect for remote dispatching") + + flags.Uint16Var(&cfg.DispatchHashringReplicationFactor, "dispatch-hashring-replication", 1000, "replication factor of the consistent hash") + flags.Uint8Var(&cfg.DispatchHashringSpread, "dispatch-hashring-spread", 1, "spread of the consistent hash") + + return cobrautil.MarkFlagsHidden( + flags, + "dispatch-remote-conn-age-limit", + "dispatch-remote-workers", + "dispatch-hashring-replication", + "dispatch-hashring-spread", + ) +} + +func RegisterDispatchFlags(flags *pflag.FlagSet, cfg *server.Config) error { + // flags.StringVar("dispatch-cache-limit", "10GiB", "TODO") + // flags.StringVar("dispatch-local-cache-limit", "10GiB", "TODO") + // flags.StringVar("dispatch-remote-cache-limit", "1GiB", "TODO") + + flags.Uint32Var(&cfg.DispatchMaxDepth, "dispatch-depth-limit", 50, "max dispatches per request") + flags.Uint16Var(&cfg.GlobalDispatchConcurrencyLimit, "dispatch-concurrency-limit", 50, "max goroutines created per CheckPermission dispatch") + + return cobrautil.MarkFlagsHidden(flags, "dispatch-concurrency-limit") +} + +func DispatchFlags(cmd *cobra.Command, cfg *server.Config) { + limit := cobrautil.MustGetUint16(cmd, "dispatch-concurrency-limit") + cfg.DispatchConcurrencyLimits = graph.ConcurrencyLimits{ + Check: limit, + ReachableResources: limit, + LookupResources: limit, + LookupSubjects: limit, + } +} + +func RegisterSeedFlags(flags *pflag.FlagSet, cfg *server.Config) { + flags.StringSliceVar(&cfg.DatastoreConfig.BootstrapFiles, "seed", nil, "local path to YAML-formatted schema and relationships file") + flags.BoolVar(&cfg.DatastoreConfig.BootstrapOverwrite, "seed-overwrite", false, "overwrite any existing data with the seed data") + flags.DurationVar(&cfg.DatastoreConfig.BootstrapTimeout, "seed-timeout", 10*time.Second, "max duration writing seed data") +} + +func RegisterAPIFlags(flags *pflag.FlagSet, cfg *server.Config) error { + // flags.Bool("api-experimental-enabled", true, "serve experimental APIs") + flags.Bool("api-schema-v1-enabled", true, "serve the v1 schema API") + flags.Bool("api-version-response", true, "expose the version over the API") + flags.BoolVar(&cfg.DatastoreConfig.ReadOnly, "api-readonly", false, "prevent any data modifications") + flags.BoolVar(&cfg.SchemaPrefixesRequired, "api-schema-prefix-required", false, "require prefixes on all object definitions") + + flags.Uint16Var(&cfg.MaximumPreconditionCount, "api-preconditions-limit", 1000, "max preconditions allowed per write and delete request") + flags.Uint16Var(&cfg.MaximumUpdatesPerWrite, "api-updates-limit", 1000, "max updates allowed per write request") + flags.DurationVar(&cfg.StreamingAPITimeout, "api-stream-timeout", 30*time.Second, "max duration between stream responses") + + flags.DurationVar(&cfg.WatchHeartbeat, "api-watch-heartbeat", 0, "watch API heartbeat interval (defaults to the datastore min)") + flags.Uint16("api-watch-buffer-size", 0, "number of watch responses buffered in memory") + + flags.DurationVar(&cfg.DatastoreConfig.RevisionQuantization, "api-quantization-interval", 5*time.Second, "boundary interval with which to round the revision") + flags.Float64Var(&cfg.DatastoreConfig.MaxRevisionStalenessPercent, "api-quantization-staleness", 0.1, "percentage of quantization interval where stale revisions can be preferred") + flags.IntVar(&cfg.MaxCaveatContextSize, "api-caveat-context-limit", 4096, "max number of bytes for caveat context per request (<=0 is no limit") + + return cobrautil.MarkFlagsHidden( + flags, + "api-schema-v1-enabled", + "api-version-response", + "api-watch-heartbeat", + "api-watch-buffer-size", + "api-caveat-context-limit", + ) +} + +func RegisterMetricsFlags(flags *pflag.FlagSet, cfg *server.Config) error { + flags.StringVar(&cfg.MetricsAPI.HTTPAddress, "metrics-addr", ":9090", "address to listen on for serving Prometheus metrics") + flags.BoolVar(&cfg.MetricsAPI.HTTPEnabled, "metrics-enabled", true, "enable the metrics server") + flags.StringVar(&cfg.MetricsAPI.HTTPTLSCertPath, "metrics-tls-cert", "", "local path to the TLS certificate") + flags.StringVar(&cfg.MetricsAPI.HTTPTLSKeyPath, "metrics-tls-key", "", "local path to the TLS key") + return nil +} + +// Universal datastore flags only! +func RegisterDatastoreFlags(flags *pflag.FlagSet, cfg *server.Config, prefix string) error { + p := func(s string) string { + return stringz.Join("-", prefix, s) + } + + flags.IntVar(&cfg.MaxRelationshipContextSize, p("caveat-context-limit"), 25000, "max number of bytes for caveat context per relationship") + return cobrautil.MarkFlagsHidden(flags, p("caveat-context-limit")) +} + +func RegisterTelemetryFlags(flags *pflag.FlagSet, cfg *server.Config) error { + flags.StringVar(&cfg.TelemetryEndpoint, "telemetry-endpoint", telemetry.DefaultEndpoint, "endpoint to which telemetry is reported, empty string to disable") + flags.StringVar(&cfg.TelemetryCAOverridePath, "telemetry-ca", "", "local path to the certificate authority used to connect to telemetry") + flags.DurationVar(&cfg.TelemetryInterval, "telemetry-interval", telemetry.DefaultInterval, "approximate duration between telemetry reports (min 1m)") + + return cobrautil.MarkFlagsHidden( + flags, + "telemetry-endpoint", + "telemetry-ca", + "telemetry-interval", + ) +} + +func ParseCombinedGRPCURI(uri string) (network, uriWithoutNetwork string) { + before, after, found := strings.Cut(uri, "://") + if !found { + return uri, "tcp" + } + return strings.ToLower(before), after +} diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 16d2449913..d0037566c9 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -43,6 +43,7 @@ import ( datastorecfg "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/util" "github.com/authzed/spicedb/pkg/datastore" + grpcservutil "github.com/authzed/spicedb/pkg/grpcutil" "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -546,7 +547,7 @@ func (c *Config) initializeGateway(ctx context.Context) (util.RunnableHTTPServer } // If the requested network is a buffered one, then disable the HTTPGateway. - if c.GRPCServer.Network == util.BufferedNetwork { + if c.GRPCServer.Network == grpcservutil.BufferedNetwork { c.HTTPGateway.HTTPEnabled = false gatewayServer, err := c.HTTPGateway.Complete(zerolog.InfoLevel, nil) if err != nil { diff --git a/pkg/cmd/testing.go b/pkg/cmd/testing.go index 8fd6d583f2..4446289162 100644 --- a/pkg/cmd/testing.go +++ b/pkg/cmd/testing.go @@ -37,6 +37,7 @@ func NewTestingCommand(programName string, config *testserver.Config) *cobra.Com Use: "serve-testing", Short: "test server with an in-memory datastore", Long: "An in-memory spicedb server which serves completely isolated datastores per client-supplied auth token used.", + GroupID: "devtools", PreRunE: server.DefaultPreRunE(programName), RunE: termination.PublishError(func(cmd *cobra.Command, args []string) error { signalctx := SignalContextWithGracePeriod( diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 6fa1c5c35d..7082faf229 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -5,7 +5,6 @@ package util import ( "context" "crypto/tls" - "crypto/x509" "errors" "fmt" "net" @@ -18,9 +17,7 @@ import ( "github.com/spf13/pflag" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/test/bufconn" // Register Snappy S2 compression _ "github.com/mostynb/go-grpc-compression/experimental/s2" @@ -29,13 +26,10 @@ import ( // Register cert watcher metrics _ "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics" - "github.com/authzed/spicedb/internal/grpchelpers" log "github.com/authzed/spicedb/internal/logging" - "github.com/authzed/spicedb/pkg/x509util" + "github.com/authzed/spicedb/pkg/grpcutil" ) -const BufferedNetwork string = "buffnet" - type GRPCServerConfig struct { Address string `debugmap:"visible"` Network string `debugmap:"visible"` @@ -71,35 +65,28 @@ func RegisterGRPCServerFlags(flags *pflag.FlagSet, config *GRPCServerConfig, fla flags.Uint32Var(&config.MaxWorkers, flagPrefix+"-max-workers", 0, "set the number of workers for this server (0 value means 1 worker per request)") } -type ( - DialFunc func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) - NetDialFunc func(ctx context.Context, s string) (net.Conn, error) -) - // Complete takes a set of default options and returns a completed server func (c *GRPCServerConfig) Complete(level zerolog.Level, svcRegistrationFn func(server *grpc.Server), opts ...grpc.ServerOption) (RunnableGRPCServer, error) { if !c.Enabled { return &disabledGrpcServer{}, nil } - if c.BufferSize == 0 { - c.BufferSize = 1024 * 1024 - } + opts = append(opts, grpc.KeepaliveParams(keepalive.ServerParameters{ MaxConnectionAge: c.MaxConnAge, }), grpc.NumStreamWorkers(c.MaxWorkers)) - tlsOpts, certWatcher, err := c.tlsOpts() + creds, certWatcher, err := grpcutil.TLSServerCreds(c.TLSCertPath, c.TLSKeyPath) if err != nil { return nil, err } - opts = append(opts, tlsOpts...) + opts = append(opts, grpc.Creds(creds)) - clientCreds, err := c.clientCreds() + clientCreds, err := grpcutil.TLSClientCreds(c.ClientCAPath, c.TLSCertPath, c.TLSKeyPath) if err != nil { return nil, err } - l, dial, netDial, err := c.listenerAndDialer() + l, dial, netDial, err := grpcutil.ListenerDialers(c.BufferSize, c.Network, c.Address) if err != nil { return nil, fmt.Errorf("failed to listen on addr for gRPC server: %w", err) } @@ -135,69 +122,6 @@ func (c *GRPCServerConfig) Complete(level zerolog.Level, svcRegistrationFn func( }, nil } -func (c *GRPCServerConfig) listenerAndDialer() (net.Listener, DialFunc, NetDialFunc, error) { - if c.Network == BufferedNetwork { - bl := bufconn.Listen(c.BufferSize) - return bl, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { - return bl.DialContext(ctx) - })) - - return grpchelpers.Dial(ctx, BufferedNetwork, opts...) - }, func(ctx context.Context, s string) (net.Conn, error) { - return bl.DialContext(ctx) - }, nil - } - l, err := net.Listen(c.Network, c.Address) - if err != nil { - return nil, nil, nil, err - } - return l, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - return grpchelpers.Dial(ctx, c.Address, opts...) - }, nil, nil -} - -func (c *GRPCServerConfig) tlsOpts() ([]grpc.ServerOption, *certwatcher.CertWatcher, error) { - switch { - case c.TLSCertPath == "" && c.TLSKeyPath == "": - return nil, nil, nil - case c.TLSCertPath != "" && c.TLSKeyPath != "": - watcher, err := certwatcher.New(c.TLSCertPath, c.TLSKeyPath) - if err != nil { - return nil, nil, err - } - creds := credentials.NewTLS(&tls.Config{ - GetCertificate: watcher.GetCertificate, - MinVersion: tls.VersionTLS12, - }) - return []grpc.ServerOption{grpc.Creds(creds)}, watcher, nil - default: - return nil, nil, nil - } -} - -func (c *GRPCServerConfig) clientCreds() (credentials.TransportCredentials, error) { - switch { - case c.TLSCertPath == "" && c.TLSKeyPath == "": - return insecure.NewCredentials(), nil - case c.TLSCertPath != "" && c.TLSKeyPath != "": - var err error - var pool *x509.CertPool - if c.ClientCAPath != "" { - pool, err = x509util.CustomCertPool(c.ClientCAPath) - } else { - pool, err = x509.SystemCertPool() - } - if err != nil { - return nil, err - } - - return credentials.NewTLS(&tls.Config{RootCAs: pool, MinVersion: tls.VersionTLS12}), nil - default: - return nil, nil - } -} - type RunnableGRPCServer interface { WithOpts(opts ...grpc.ServerOption) RunnableGRPCServer Listen(ctx context.Context) func() error diff --git a/pkg/grpcutil/prometheus.go b/pkg/grpcutil/prometheus.go new file mode 100644 index 0000000000..e6597b5ca9 --- /dev/null +++ b/pkg/grpcutil/prometheus.go @@ -0,0 +1,47 @@ +package grpcutil + +import ( + "context" + "sync" + + grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + "github.com/prometheus/client_golang/prometheus" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" +) + +var ( + serverMetricsOnce sync.Once + serverMetrics *grpcprom.ServerMetrics +) + +func exemplarFromContextFunc(ctx context.Context) prometheus.Labels { + if span := trace.SpanContextFromContext(ctx); span.IsSampled() { + return prometheus.Labels{"traceID": span.TraceID().String()} + } + return nil +} + +func srvMetrics() *grpcprom.ServerMetrics { + serverMetricsOnce.Do(func() { + serverMetrics = grpcprom.NewServerMetrics([]grpcprom.ServerMetricsOption{ + grpcprom.WithServerHandlingTimeHistogram( + grpcprom.WithHistogramBuckets([]float64{.001, .003, .006, .010, .018, .024, .032, .042, .056, .075, .100, .178, .316, .562, 1, 5}), + ), + }...) + + // Deliberately ignore if these metrics were already registered, so that + // these metrics can be optionally registered with custom labels. + _ = prometheus.Register(serverMetrics) + }) + + return serverMetrics +} + +func PrometheusUnaryInterceptor() grpc.UnaryServerInterceptor { + return srvMetrics().UnaryServerInterceptor(grpcprom.WithExemplarFromContext(exemplarFromContextFunc)) +} + +func PrometheusStreamInterceptor() grpc.StreamServerInterceptor { + return srvMetrics().StreamServerInterceptor(grpcprom.WithExemplarFromContext(exemplarFromContextFunc)) +} diff --git a/pkg/grpcutil/tls.go b/pkg/grpcutil/tls.go new file mode 100644 index 0000000000..67f36f82da --- /dev/null +++ b/pkg/grpcutil/tls.go @@ -0,0 +1,113 @@ +package grpcutil + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" + + "github.com/authzed/spicedb/pkg/x509util" +) + +// BufferedNetwork is a gRPC network that operates purely in memory. +const BufferedNetwork string = "buffnet" + +type ( + DialFunc func(context.Context, ...grpc.DialOption) (*grpc.ClientConn, error) + NetDialFunc func(ctx context.Context, addr string) (net.Conn, error) +) + +// NewBuffNet creates functions for binding a gRPC server to an +// in-memory buffer connection. +// +// This type of connection is useful for tests or in-process communication. +func NewBuffNet(bufferSize int) (net.Listener, DialFunc, NetDialFunc, error) { + if bufferSize == 0 { + bufferSize = 1024 * 1024 + } + + l := bufconn.Listen(bufferSize) + return l, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + return l.DialContext(ctx) + })) + + return grpc.DialContext(ctx, BufferedNetwork, opts...) + }, func(ctx context.Context, s string) (net.Conn, error) { + return l.DialContext(ctx) + }, nil +} + +// ListenerDialers returns functions for binding a gRPC server to the network +// and creating clients to that server. +// +// This function includes support for buffer connections. +func ListenerDialers(bufferSize int, network, addr string) (net.Listener, DialFunc, NetDialFunc, error) { + if network == BufferedNetwork { + return NewBuffNet(bufferSize) + } + l, err := net.Listen(network, addr) + if err != nil { + return nil, nil, nil, err + } + return l, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return grpc.DialContext(ctx, addr, opts...) + }, nil, nil +} + +// TLSServerCreds constructs TransportCredentials for a gRPC server using the +// provided filepaths. +// +// Certificates are watched and reloaded upon change. +func TLSServerCreds(certPath, keyPath string) (credentials.TransportCredentials, *certwatcher.CertWatcher, error) { + switch { + case certPath != "" && keyPath != "": + watcher, err := certwatcher.New(certPath, keyPath) + if err != nil { + return nil, nil, err + } + return credentials.NewTLS(&tls.Config{ + GetCertificate: watcher.GetCertificate, + MinVersion: tls.VersionTLS12, + }), watcher, nil + default: + return nil, nil, nil + } +} + +// TLSClientCreds constructs TransportCredentials for a gRPC connection with +// the provided filepaths. +// +// If the caPath is not provided, the system certifcate pool will be used. +// If both certPath and keyPath are not provided, an insecure transport is +// returned. +func TLSClientCreds(caPath, certPath, keyPath string) (credentials.TransportCredentials, error) { + switch { + case certPath == "" && keyPath == "": + return insecure.NewCredentials(), nil + case certPath != "" && keyPath != "": + var err error + var pool *x509.CertPool + if caPath != "" { + pool, err = x509util.CustomCertPool(caPath) + } else { + pool, err = x509.SystemCertPool() + } + if err != nil { + return nil, err + } + + return credentials.NewTLS(&tls.Config{ + RootCAs: pool, + MinVersion: tls.VersionTLS12, + }), nil + default: + return nil, nil + } +} diff --git a/pkg/grpcutil/zerolog.go b/pkg/grpcutil/zerolog.go new file mode 100644 index 0000000000..0882afe190 --- /dev/null +++ b/pkg/grpcutil/zerolog.go @@ -0,0 +1,66 @@ +package grpcutil + +import ( + "context" + "time" + + grpclog "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/rs/zerolog" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" +) + +var ( + durationFieldOption = grpclog.WithDurationField(func(duration time.Duration) grpclog.Fields { + return grpclog.Fields{"grpc.time_ms", duration.Milliseconds()} + }) + + traceIDFieldOption = grpclog.WithFieldsFromContext(func(ctx context.Context) grpclog.Fields { + if span := trace.SpanContextFromContext(ctx); span.IsSampled() { + return grpclog.Fields{"traceID", span.TraceID().String()} + } + return nil + }) +) + +// ZerologUnaryInterceptor maps gRPC logging to Zerolog according to the +// provided code-mapping function. +func ZerologUnaryInterceptor(l zerolog.Logger, mappingfn grpclog.CodeToLevel) grpc.UnaryServerInterceptor { + return grpclog.UnaryServerInterceptor( + zerologger(l), + grpclog.WithLevels(mappingfn), + durationFieldOption, + traceIDFieldOption, + ) +} + +// ZerologStreamInterceptor maps gRPC logging to Zerolog according to the +// provided code-mapping function. +func ZerologStreamInterceptor(l zerolog.Logger, mappingfn grpclog.CodeToLevel) grpc.StreamServerInterceptor { + return grpclog.StreamServerInterceptor( + zerologger(l), + grpclog.WithLevels(mappingfn), + durationFieldOption, + traceIDFieldOption, + ) +} + +func zerologger(l zerolog.Logger) grpclog.Logger { + return grpclog.LoggerFunc(func(ctx context.Context, lvl grpclog.Level, msg string, fields ...any) { + l := l.With().Fields(fields).Logger() + + switch lvl { + case grpclog.LevelDebug: + l.Debug().Msg(msg) + case grpclog.LevelInfo: + l.Info().Msg(msg) + case grpclog.LevelWarn: + l.Warn().Msg(msg) + case grpclog.LevelError: + l.Error().Msg(msg) + default: + l.Error().Int("level", int(lvl)).Msg("unknown error level - falling back to info level") + l.Info().Msg(msg) + } + }) +} diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go index b33b81b6cd..a84bd0e60d 100644 --- a/pkg/migrate/migrate.go +++ b/pkg/migrate/migrate.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "time" log "github.com/authzed/spicedb/internal/logging" ) @@ -211,3 +212,25 @@ func collectMigrationsInRange[C any, T any](starting, through string, all map[st return found, nil } + +func RunMigration[D Driver[C, T], C any, T any]( + ctx context.Context, + driver D, + manager *Manager[D, C, T], + targetRevision string, + timeout time.Duration, + backfillBatchSize uint64, +) error { + log.Ctx(ctx).Info().Str("targetRevision", targetRevision).Msg("running migrations") + ctxWithBatch := context.WithValue(ctx, BackfillBatchSize, backfillBatchSize) + ctx, cancel := context.WithTimeout(ctxWithBatch, timeout) + defer cancel() + if err := manager.Run(ctx, driver, targetRevision, LiveRun); err != nil { + return fmt.Errorf("unable to migrate to `%s` revision: %w", targetRevision, err) + } + + if err := driver.Close(ctx); err != nil { + return fmt.Errorf("unable to close migration driver: %w", err) + } + return nil +}