fix: automatic scale up for raft clusters (#4324)
Adds the Raft control service to the existing gRPC servers, and uses it to
join the cluster if the node is not an initial member.

Scaling down is not automatic yet. When we scale down, we will need to
do it manually via an endpoint I will add next.
jvmakine authored Feb 6, 2025
1 parent 0037b8b commit b6a0844
Showing 5 changed files with 127 additions and 52 deletions.
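In short: every schema service replica now registers the Raft control service on its existing gRPC server, and on startup it either bootstraps the cluster (if its address is one of the initial members) or joins the running cluster via the configured control address. A condensed sketch of the wiring this commit introduces, assembled from the diffs below (identifiers come from the changed files; the Start signature is abbreviated, and error handling, signal wiring, and the local in-memory path are trimmed):

func startSchemaService(ctx context.Context, config Config) error {
	clusterBuilder := raft.NewBuilder(&config.Raft)
	shard := raft.AddShard(ctx, clusterBuilder, 1, newStateMachine(ctx))
	cluster := clusterBuilder.Build(ctx)

	svc := New(ctx, shard, config)

	// RPCOption registers the Raft control service on the same gRPC server
	// and runs cluster.Start/Stop as server start/shutdown hooks; Start
	// either bootstraps (initial member) or joins via RAFT_CONTROL_ADDRESS.
	return rpc.Serve(ctx, config.Bind,
		raft.RPCOption(cluster),
		rpc.GRPC(ftlv1connect.NewSchemaServiceHandler, svc),
		rpc.PProf(),
	)
}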
26 changes: 18 additions & 8 deletions backend/schemaservice/schemaservice.go
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net/url"
"os/signal"
"syscall"

"connectrpc.com/connect"
"golang.org/x/sync/errgroup"
@@ -68,34 +70,42 @@ func Start(
logger.Debugf("Starting FTL schema service")

g, gctx := errgroup.WithContext(ctx)
gctx, cancel := signal.NotifyContext(gctx, syscall.SIGTERM)
defer cancel()

var rpcOpts []rpc.Option

var shard statemachine.Handle[struct{}, SchemaState, EventWrapper]
if config.Raft.DataDir == "" {
// in local dev mode, use an inmemory state machine
shard = statemachine.NewLocalHandle(newStateMachine(ctx))
} else {
clusterBuilder := raft.NewBuilder(&config.Raft)
schemaShard := raft.AddShard(ctx, clusterBuilder, 1, newStateMachine(ctx))
cluster := clusterBuilder.Build(ctx)
if err := cluster.Start(ctx); err != nil {
return fmt.Errorf("failed to start raft cluster: %w", err)
}
schemaShard := raft.AddShard(gctx, clusterBuilder, 1, newStateMachine(ctx))
cluster := clusterBuilder.Build(gctx)
shard = schemaShard

rpcOpts = append(rpcOpts, raft.RPCOption(cluster))
}

svc := New(ctx, shard, config)
logger.Debugf("Listening on %s", config.Bind)

g.Go(func() error {
return rpc.Serve(gctx, config.Bind,
rpc.GRPC(ftlv1connect.NewSchemaServiceHandler, svc),
rpc.PProf(),
append(rpcOpts,
rpc.GRPC(ftlv1connect.NewSchemaServiceHandler, svc),
rpc.PProf(),
)...,
)
})

err := g.Wait()
if err != nil {
return fmt.Errorf("failed to start schema service: %w", err)
if gctx.Err() == nil {
// startup failure if the context was not cancelled
return fmt.Errorf("failed to start schema service: %w", err)
}
}
return nil
}
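The error handling around g.Wait() above is deliberate: the group context is now cancelled on SIGTERM, so an error that surfaces after cancellation is treated as an orderly shutdown, and only errors raised while the context is still live are reported as startup failures. A minimal standalone sketch of the same pattern, reusing the imports added above (the work function is a hypothetical stand-in for the served component):

// runUntilSignalled runs work until it fails or SIGTERM arrives; errors that
// occur after the signal are treated as part of a clean shutdown.
func runUntilSignalled(ctx context.Context, work func(ctx context.Context) error) error {
	g, gctx := errgroup.WithContext(ctx)
	gctx, cancel := signal.NotifyContext(gctx, syscall.SIGTERM)
	defer cancel()

	g.Go(func() error { return work(gctx) })

	if err := g.Wait(); err != nil {
		if gctx.Err() == nil {
			// Context still live, so this is a genuine failure, not a shutdown.
			return fmt.Errorf("service failed: %w", err)
		}
	}
	return nil
}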
2 changes: 2 additions & 0 deletions charts/ftl/values.yaml
@@ -375,6 +375,8 @@ schema:
value: "$(ORDINAL_NUMBER)"
- name: RAFT_LISTEN_ADDRESS
value: "$(MY_POD_IP):8992"
- name: RAFT_CONTROL_ADDRESS
value: "http://ftl-schema:8892"

revisionHistoryLimit: 0

74 changes: 42 additions & 32 deletions internal/raft/cluster.go
@@ -35,7 +35,7 @@ type RaftConfig struct {
DataDir string `help:"Data directory" required:"" env:"RAFT_DATA_DIR"`
Address string `help:"Address to advertise to other nodes" required:"" env:"RAFT_ADDRESS"`
ListenAddress string `help:"Address to listen for incoming traffic. If empty, Address will be used." env:"RAFT_LISTEN_ADDRESS"`
ControlBind *url.URL `help:"Address to listen for control traffic. If empty, no control listener will be started."`
ControlAddress *url.URL `help:"Address to connect to the control server" env:"RAFT_CONTROL_ADDRESS"`
ShardReadyTimeout time.Duration `help:"Timeout for shard to be ready" default:"5s"`
Retry retry.RetryConfig `help:"Connection retry configuration" prefix:"retry-" embed:""`
ChangesInterval time.Duration `help:"Interval for changes to be checked" default:"10ms"`
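Note the change of meaning here: ControlBind was a listen address for a separate control listener, while ControlAddress is a client-side URL. The control service now rides on the main gRPC server (see RPCOption below), so a joining node only needs an address it can dial, which is what the chart's new RAFT_CONTROL_ADDRESS env var supplies. For illustration only, resolving that value by hand would look roughly like this; in the real binary the config loader populates the field from the env tag above (assumes the usual fmt, net/url, and os imports):

// Illustrative sketch only; the config loader normally fills ControlAddress.
func controlAddressFromEnv() (*url.URL, error) {
	raw := os.Getenv("RAFT_CONTROL_ADDRESS")
	if raw == "" {
		// No control address configured: this node must be an initial member.
		return nil, nil
	}
	u, err := url.Parse(raw)
	if err != nil {
		return nil, fmt.Errorf("invalid RAFT_CONTROL_ADDRESS %q: %w", raw, err)
	}
	return u, nil
}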
@@ -270,6 +270,24 @@ func (s *ShardHandle[Q, R, E]) StateIter(ctx context.Context, query Q) (iter.Seq
return iterops.Dedup(channels.IterContext(ctx, result)), nil
}

func RPCOption(cluster *Cluster) rpc.Option {
return rpc.Options(
rpc.StartHook(func(ctx context.Context) error {
if err := cluster.Start(ctx); err != nil {
return fmt.Errorf("failed to start raft cluster: %w", err)
}
return nil
}),
rpc.ShutdownHook(func(ctx context.Context) error {
logger := log.FromContext(ctx)
logger.Debugf("stopping raft cluster")
cluster.Stop(ctx)
return nil
}),
rpc.GRPC(raftpbconnect.NewRaftServiceHandler, cluster),
)
}

func (s *ShardHandle[Q, R, E]) getLastIndex() (uint64, error) {
s.verifyReady()

@@ -292,10 +310,24 @@ func (s *ShardHandle[Q, R, E]) verifyReady() {

// Start the cluster. Blocks until the cluster instance is ready.
func (c *Cluster) Start(ctx context.Context) error {
logger := log.FromContext(ctx).Scope("raft")
if c.nh != nil {
panic("cluster already started")
}

isInitial := false
for _, member := range c.config.InitialMembers {
if member == c.config.Address {
isInitial = true
break
}
}

if !isInitial {
logger.Infof("joining cluster as a new member")
return c.Join(ctx, c.config.ControlAddress.String())
}
logger.Infof("joining cluster as an initial member")
return c.start(ctx, false)
}

@@ -331,6 +363,8 @@ func (c *Cluster) Join(ctx context.Context, controlAddress string) error {
}

func (c *Cluster) start(ctx context.Context, join bool) error {
logger := log.FromContext(ctx).Scope("raft")

// Create node host config
nhc := config.NodeHostConfig{
WALDir: c.config.DataDir,
@@ -356,19 +390,17 @@

// Wait for all shards to be ready
for shardID := range c.shards {
if err := c.waitReady(ctx, shardID); err != nil {
err := c.waitReady(ctx, shardID)
if err != nil {
return fmt.Errorf("failed to wait for shard %d to be ready on replica %d: %w", shardID, c.config.ReplicaID, err)
}
}
logger.Infof("All shards are ready")

ctx, cancel := context.WithCancelCause(context.WithoutCancel(ctx))
c.runningCtxCancel = cancel
c.runningCtx = ctx

if err := c.startControlServer(ctx); err != nil {
return err
}

return nil
}

@@ -398,40 +430,17 @@ func (c *Cluster) startShard(nh *dragonboat.NodeHost, shardID uint64, sm statema
return nil
}

func (c *Cluster) startControlServer(ctx context.Context) error {
logger := log.FromContext(ctx).Scope("raft")

if c.config.ControlBind == nil {
return nil
}

logger.Infof("Starting control server on %s", c.config.ControlBind.String())
go func() {
err := rpc.Serve(ctx, c.config.ControlBind,
rpc.GRPC(raftpbconnect.NewRaftServiceHandler, c),
rpc.PProf())
if err != nil && !errors.Is(err, context.Canceled) {
logger.Errorf(err, "error serving control listener")
}
logger.Infof("Control server stopped")
}()
return nil
}

// Stop the node host and all shards.
// After this call, all the shard handlers created with this cluster are invalid.
func (c *Cluster) Stop(ctx context.Context) {
logger := log.FromContext(ctx).Scope("raft")
if c.nh != nil {
logger := log.FromContext(ctx).Scope("raft")
logger.Infof("stopping replica %d", c.config.ReplicaID)

for shardID := range c.shards {
c.removeShardMember(ctx, shardID, c.config.ReplicaID)
}
c.runningCtxCancel(fmt.Errorf("stopping raft cluster: %w", context.Canceled))
c.nh.Close()
c.nh = nil
c.shards = nil
} else {
logger.Debugf("raft cluster already stopped")
}
}

@@ -547,5 +556,6 @@ func (c *Cluster) waitReady(ctx context.Context, shardID uint64) error {
}
break
}
logger.Debugf("Shard %d on replica %d is ready", shardID, c.config.ReplicaID)
return nil
}
1 change: 0 additions & 1 deletion internal/raft/cluster_test.go
@@ -182,7 +182,6 @@ func testBuilder(t *testing.T, addresses []*net.TCPAddr, id uint64, address stri
return raft.NewBuilder(&raft.RaftConfig{
ReplicaID: id,
Address: address,
ControlBind: controlBind,
DataDir: t.TempDir(),
InitialMembers: members,
HeartbeatRTT: 1,
76 changes: 65 additions & 11 deletions internal/rpc/server.go
@@ -3,6 +3,7 @@ package rpc
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
@@ -17,6 +18,7 @@ import (
"golang.org/x/net/http2/h2c"

gaphttp "github.com/block/ftl/internal/http"
"github.com/block/ftl/internal/log"
)

const ShutdownGracePeriod = time.Second * 5
Expand All @@ -25,6 +27,8 @@ type serverOptions struct {
mux *http.ServeMux
reflectionPaths []string
healthCheck http.HandlerFunc
startHooks []func(ctx context.Context) error
shutdownHooks []func(ctx context.Context) error
}

type Option func(*serverOptions)
@@ -73,10 +77,35 @@ func HTTP(prefix string, handler http.Handler) Option {
}
}

// ShutdownHook is called when the server is shutting down.
func ShutdownHook(hook func(ctx context.Context) error) Option {
return func(so *serverOptions) {
so.shutdownHooks = append(so.shutdownHooks, hook)
}
}

// StartHook is called when the server is starting up.
func StartHook(hook func(ctx context.Context) error) Option {
return func(so *serverOptions) {
so.startHooks = append(so.startHooks, hook)
}
}

// Options is a convenience function for aggregating multiple options.
func Options(options ...Option) Option {
return func(so *serverOptions) {
for _, option := range options {
option(so)
}
}
}
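StartHook, ShutdownHook, and Options are general-purpose server extensions; RPCOption in internal/raft/cluster.go is just one bundle of them. Note that Serve logs hook errors rather than returning them, so a hook that must abort the process has to handle that itself. A hypothetical usage sketch (cache, bind, and svc are placeholders, not identifiers from this repo):

if err := rpc.Serve(ctx, bind,
	rpc.StartHook(func(ctx context.Context) error {
		// Runs before the listener starts serving traffic.
		return cache.Warm(ctx)
	}),
	rpc.ShutdownHook(func(ctx context.Context) error {
		// Runs while the server shuts down; best effort, errors are only logged.
		return cache.Flush(ctx)
	}),
	rpc.GRPC(ftlv1connect.NewSchemaServiceHandler, svc),
); err != nil {
	return fmt.Errorf("server exited: %w", err)
}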

type Server struct {
listen *url.URL
Bind *pubsub.Topic[*url.URL] // Will be updated with the actual bind address.
Server *http.Server
listen *url.URL
shutdownHooks []func(ctx context.Context) error
startHooks []func(ctx context.Context) error
Bind *pubsub.Topic[*url.URL] // Will be updated with the actual bind address.
Server *http.Server
}

func NewServer(ctx context.Context, listen *url.URL, options ...Option) (*Server, error) {
@@ -106,9 +135,11 @@ func NewServer(ctx context.Context, listen *url.URL, options ...Option) (*Server
}

return &Server{
listen: listen,
Bind: pubsub.New[*url.URL](),
Server: http1Server,
listen: listen,
shutdownHooks: opts.shutdownHooks,
startHooks: opts.startHooks,
Bind: pubsub.New[*url.URL](),
Server: http1Server,
}, nil
}

@@ -127,30 +158,53 @@ func (s *Server) Serve(ctx context.Context) error {

// Shutdown server on context cancellation.
tree.Go(func(ctx context.Context) error {
logger := log.FromContext(ctx)

<-ctx.Done()

ctx, cancel := context.WithTimeout(context.Background(), ShutdownGracePeriod)
defer cancel()
ctx = log.ContextWithLogger(ctx, logger)

err := s.Server.Shutdown(ctx)
if err == nil {
return nil
}
if errors.Is(err, context.Canceled) {
_ = s.Server.Close()
return err
}
return err

for i, hook := range s.shutdownHooks {
logger.Debugf("Running shutdown hook %d/%d", i+1, len(s.shutdownHooks))
if err := hook(ctx); err != nil {
logger.Errorf(err, "shutdown hook failed")
}
}

return nil
})

// Start server.
tree.Go(func(ctx context.Context) error {
logger := log.FromContext(ctx)
for i, hook := range s.startHooks {
logger.Debugf("Running start hook %d/%d", i+1, len(s.startHooks))
if err := hook(ctx); err != nil {
logger.Errorf(err, "start hook failed")
}
}

err = s.Server.Serve(listener)
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
})

return tree.Wait()
err = tree.Wait()
if err != nil {
return fmt.Errorf("failed to start server: %w", err)
}

return nil
}

// Serve starts a HTTP and Connect gRPC server with sane defaults for FTL.