Fix bug in tick scheduler command
Refactor worker pool client handling and improve logging
billettc committed Feb 6, 2025
1 parent cab50b3 commit 13eef93
Showing 7 changed files with 54 additions and 57 deletions.
7 changes: 2 additions & 5 deletions app/tier2.go
@@ -10,12 +10,11 @@ import (
"github.com/streamingfast/dmetrics"
"github.com/streamingfast/shutter"
"github.com/streamingfast/substreams/metrics"
"github.com/streamingfast/substreams/orchestrator/work"
"github.com/streamingfast/substreams/pipeline"
"github.com/streamingfast/substreams/service"
"github.com/streamingfast/substreams/wasm"
"github.com/streamingfast/substreams/wasm/wazero"
"github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1/pbworkerconnect"
pbworker "github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1"
"go.uber.org/atomic"
"go.uber.org/zap"
)
@@ -44,8 +43,7 @@ type Tier2App struct {

type Tier2Modules struct {
CheckPendingShutDown func() bool
RemoteWorkerClient pbworkerconnect.WorkerPoolClient
WorkerPoolFactory work.WorkerPoolFactory
RemoteWorkerClient pbworker.WorkerPoolClient
}

func NewTier2(logger *zap.Logger, config *Tier2Config, modules *Tier2Modules) *Tier2App {
@@ -95,7 +93,6 @@ func (a *Tier2App) Run() error {

svc, err := service.NewTier2(
a.modules.RemoteWorkerClient,
a.modules.WorkerPoolFactory,
a.logger,
opts...,
)
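
Tier2Modules now carries the gRPC-generated pbworker.WorkerPoolClient directly and drops the WorkerPoolFactory field. Below is a minimal wiring sketch of how such a client could be built and handed to the app. It assumes the worker-pool proto was also generated with protoc-gen-go-grpc (so that a NewWorkerPoolClient constructor exists) and uses a hypothetical localhost endpoint; neither detail is confirmed by this diff.

```go
package main

import (
	"log"

	pbworker "github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Hypothetical endpoint; substitute the real worker-pool address.
	conn, err := grpc.NewClient("localhost:9000",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dialing worker pool: %v", err)
	}
	defer conn.Close()

	// Assumed protoc-gen-go-grpc constructor for the WorkerPool service.
	var client pbworker.WorkerPoolClient = pbworker.NewWorkerPoolClient(conn)
	_ = client // would be passed as Tier2Modules.RemoteWorkerClient
}
```
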
4 changes: 2 additions & 2 deletions orchestrator/loop/cmd.go
@@ -46,9 +46,9 @@ func Quit(err error) Cmd {
}
}

func Tick(delay time.Duration, fn func() Msg) Cmd {
func Tick(delay time.Duration, msg Msg) Cmd {
return func() Msg {
time.Sleep(delay)
return fn()
return msg
}
}
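
This signature change is the bug fix named in the commit title: loop.Msg is an empty interface, so the old Tick happily accepted a callback that returned a Cmd, and that Cmd was then delivered to the scheduler as if it were a message. A standalone sketch with simplified stand-ins for the loop types (Msg, Cmd and MsgScheduleNextJob below are assumptions mirroring the real package, not the actual code):

```go
package main

import (
	"fmt"
	"time"
)

// Simplified stand-ins for the loop package types (assumptions, not the real code).
type Msg interface{}
type Cmd func() Msg

type MsgScheduleNextJob struct{}

// CmdScheduleNextJob mirrors the assumed helper: a Cmd whose result is the message.
func CmdScheduleNextJob() Cmd {
	return func() Msg { return MsgScheduleNextJob{} }
}

// Old signature: Tick ran an arbitrary callback, so nothing stopped a call site
// from returning a Cmd where a Msg was expected.
func oldTick(delay time.Duration, fn func() Msg) Cmd {
	return func() Msg { time.Sleep(delay); return fn() }
}

// New signature: the message value itself is captured, so the scheduler always
// receives a plain Msg after the delay.
func newTick(delay time.Duration, msg Msg) Cmd {
	return func() Msg { time.Sleep(delay); return msg }
}

func main() {
	buggy := oldTick(time.Millisecond, func() Msg { return CmdScheduleNextJob() })
	fixed := newTick(time.Millisecond, MsgScheduleNextJob{})
	fmt.Printf("old call site delivers %T, new call site delivers %T\n", buggy(), fixed())
	// old call site delivers main.Cmd, new call site delivers main.MsgScheduleNextJob
}
```

Passing the message value instead of a callback makes the misuse impossible to express at the call site.
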
9 changes: 6 additions & 3 deletions orchestrator/scheduler/scheduler.go
@@ -37,6 +37,7 @@ type Scheduler struct {

func New(ctx context.Context, stream *response.Stream) *Scheduler {
logger := reqctx.Logger(ctx)
logger = logger.Named("scheduler")
s := &Scheduler{
ctx: ctx,
stream: stream,
@@ -80,6 +81,10 @@ func (s *Scheduler) Update(msg loop.Msg) loop.Cmd {
var cmds []loop.Cmd

switch msg := msg.(type) {
case loop.Cmd:
c := msg()
err := fmt.Sprintf("receive loop cmd instead of %T\n", c)
panic(err)
case work.MsgJobSucceeded:

shadowedUnits := s.Stages.MarkJobSuccess(msg.Unit)
@@ -105,7 +110,6 @@ func (s *Scheduler) Update(msg loop.Msg) loop.Cmd {
if s.ExecOutWalker != nil {
cmds = append(cmds, execout.CmdDownloadSegment(0))
}

case work.MsgScheduleNextJob:
worker, err := s.WorkerPool.Borrow(s.ctx)
if err != nil {
@@ -114,8 +118,7 @@ func (s *Scheduler) Update(msg loop.Msg) loop.Cmd {
} else {
s.logger.Error("scheduler: failed to borrow worker", zap.Error(err))
}
cmds = append(cmds, loop.Tick(time.Second, func() loop.Msg { return work.CmdScheduleNextJob() }))
break
return loop.Tick(1*time.Second, work.MsgScheduleNextJob{})
}
workUnit, workRange := s.Stages.NextJob()
if workRange == nil { // End of job
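
With the fixed Tick, the MsgScheduleNextJob handler can return a command that simply re-delivers the same message after one second when no worker can be borrowed, and the new loop.Cmd case turns any leftover cmd-as-message call site into an immediate panic. A compressed, self-contained sketch of that control flow (the types and the poolExhausted flag are simplified assumptions, not the real scheduler):

```go
package main

import (
	"fmt"
	"time"
)

// Simplified stand-ins for the loop types (assumptions, as in the sketch above).
type Msg interface{}
type Cmd func() Msg

type MsgScheduleNextJob struct{}

// Tick matches the fixed loop.Tick: the message value is carried directly.
func Tick(d time.Duration, msg Msg) Cmd {
	return func() Msg { time.Sleep(d); return msg }
}

// update mirrors only the MsgScheduleNextJob branch of Scheduler.Update;
// poolExhausted stands in for a failed WorkerPool.Borrow.
func update(msg Msg, poolExhausted bool) Cmd {
	switch msg.(type) {
	case Cmd:
		// A Cmd delivered as a message is now treated as a programming error.
		panic(fmt.Sprintf("received a loop Cmd instead of a Msg: %T", msg))
	case MsgScheduleNextJob:
		if poolExhausted {
			// No worker available: ask the loop to re-deliver the same
			// message after one second instead of spinning.
			return Tick(time.Second, MsgScheduleNextJob{})
		}
		// ...borrow a worker and dispatch the next job here...
	}
	return nil
}

func main() {
	retry := update(MsgScheduleNextJob{}, true)
	fmt.Printf("after the delay the loop re-delivers %T\n", retry()) // main.MsgScheduleNextJob
}
```
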
6 changes: 3 additions & 3 deletions orchestrator/work/factory.go
@@ -8,18 +8,18 @@ import (
tracing "github.com/streamingfast/sf-tracing"
"github.com/streamingfast/substreams/client"
"github.com/streamingfast/substreams/reqctx"
"github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1/pbworkerconnect"
pbworker "github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1"
)

type WorkerPoolFactory func(ctx context.Context) WorkerPool

type GlobalWorkerPoolFactory struct {
clientFactory client.InternalClientFactory
remoteWorkerPool pbworkerconnect.WorkerPoolClient
remoteWorkerPool pbworker.WorkerPoolClient
workerKeepAliveDelay time.Duration
}

func NewGlobalWorkerPoolFactory(remoteWorkerPool pbworkerconnect.WorkerPoolClient, clientFactory client.InternalClientFactory, workerKeepAliveDelay time.Duration) *GlobalWorkerPoolFactory {
func NewGlobalWorkerPoolFactory(remoteWorkerPool pbworker.WorkerPoolClient, clientFactory client.InternalClientFactory, workerKeepAliveDelay time.Duration) *GlobalWorkerPoolFactory {
return &GlobalWorkerPoolFactory{
remoteWorkerPool: remoteWorkerPool,
workerKeepAliveDelay: workerKeepAliveDelay,
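
Only the client type changes here; WorkerPoolFactory itself stays a plain func(ctx) WorkerPool. For readers unfamiliar with the pattern, a small self-contained sketch of why the pool sits behind a factory: construction is deferred until a request context with per-request values (trace ID, user ID) exists. The interfaces and names below are simplified assumptions, not the real orchestrator/work types.

```go
package main

import (
	"context"
	"fmt"
)

// Simplified stand-ins for the orchestrator/work interfaces (assumptions).
type Worker interface{ ID() string }

type WorkerPool interface {
	Borrow(ctx context.Context) (Worker, error)
}

// WorkerPoolFactory defers pool construction until a request context exists,
// so per-request values (user ID, trace ID) can be read at build time.
type WorkerPoolFactory func(ctx context.Context) WorkerPool

type ctxKey string

type fakeWorker struct{ id string }

func (w fakeWorker) ID() string { return w.id }

type fakePool struct{ traceID string }

func (p fakePool) Borrow(ctx context.Context) (Worker, error) {
	return fakeWorker{id: "worker-for-" + p.traceID}, nil
}

func main() {
	// The factory closure captures long-lived dependencies; the per-request
	// trace ID is only read when the pool is actually built.
	factory := WorkerPoolFactory(func(ctx context.Context) WorkerPool {
		traceID, _ := ctx.Value(ctxKey("trace_id")).(string)
		return fakePool{traceID: traceID}
	})

	ctx := context.WithValue(context.Background(), ctxKey("trace_id"), "abc123")
	w, _ := factory(ctx).Borrow(ctx)
	fmt.Println(w.ID()) // worker-for-abc123
}
```
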
63 changes: 31 additions & 32 deletions orchestrator/work/globalworkerpool.go
@@ -6,11 +6,9 @@ import (
"fmt"
"time"

"connectrpc.com/connect"
"github.com/streamingfast/substreams/client"
"github.com/streamingfast/substreams/reqctx"
pbworker "github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1"
"github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1/pbworkerconnect"
"go.uber.org/zap"
)

@@ -20,70 +18,71 @@ type GlobalWorkerPool struct {
startedAt time.Time
firstWorkerServed bool

remoteWorkerPool pbworkerconnect.WorkerPoolClient
logger *zap.Logger
clientFactory client.InternalClientFactory
workerKeepAliveDelay time.Duration
maxWorkerForTraceID uint64
remoteWorkerPoolClient pbworker.WorkerPoolClient
logger *zap.Logger
clientFactory client.InternalClientFactory
workerKeepAliveDelay time.Duration
maxWorkerForTraceID uint64
}

func NewGlobalWorkerPool(ctx context.Context, userID string, traceID string, maxWorkerForTraceID uint64, remoteWorkerPool pbworkerconnect.WorkerPoolClient, clientFactory client.InternalClientFactory, workerKeepAliveDelay time.Duration) *GlobalWorkerPool {
func NewGlobalWorkerPool(ctx context.Context, userID string, traceID string, maxWorkerForTraceID uint64, remoteWorkerPoolClient pbworker.WorkerPoolClient, clientFactory client.InternalClientFactory, workerKeepAliveDelay time.Duration) *GlobalWorkerPool {
logger := reqctx.Logger(ctx)
logger = logger.Named("global-worker-pool")

logger.Debug("initializing worker pool", zap.String("user_id", userID), zap.String("trace_id", traceID))
logger.Info("initializing worker pool", zap.String("user_id", userID), zap.String("trace_id", traceID))

return &GlobalWorkerPool{
userID: userID,
traceID: traceID,
maxWorkerForTraceID: maxWorkerForTraceID,
remoteWorkerPool: remoteWorkerPool,
startedAt: time.Now(),
clientFactory: clientFactory,
workerKeepAliveDelay: workerKeepAliveDelay,
logger: logger,
userID: userID,
traceID: traceID,
maxWorkerForTraceID: maxWorkerForTraceID,
remoteWorkerPoolClient: remoteWorkerPoolClient,
startedAt: time.Now(),
clientFactory: clientFactory,
workerKeepAliveDelay: workerKeepAliveDelay,
logger: logger,
}
}

var ErrorResourceExhausted = errors.New("resource exhausted")

func (p *GlobalWorkerPool) Borrow(ctx context.Context) (Worker, error) {
rampUpCompleted := time.Since(p.startedAt) < time.Second*4
if !rampUpCompleted && p.firstWorkerServed {
rampingUp := time.Since(p.startedAt) < time.Second*4
if rampingUp && p.firstWorkerServed {
p.logger.Info("worker pool is exhausted because of ramp up", zap.Bool("first_worker_served", p.firstWorkerServed), zap.Bool("ramping_up", rampingUp), zap.Duration("time_since_start", time.Since(p.startedAt)))
return nil, ErrorResourceExhausted
}

response, err := p.remoteWorkerPool.BorrowWorker(ctx,
&connect.Request[pbworker.BorrowWorkerRequest]{
Msg: &pbworker.BorrowWorkerRequest{
UserId: p.userID,
TraceId: p.traceID,
MaxWorkerForTraceId: int64(p.maxWorkerForTraceID),
},
response, err := p.remoteWorkerPoolClient.BorrowWorker(ctx,
&pbworker.BorrowWorkerRequest{
UserId: p.userID,
TraceId: p.traceID,
MaxWorkerForTraceId: int64(p.maxWorkerForTraceID),
},
)

if err != nil {
return nil, fmt.Errorf("borrowing worker for user %q and trace %q: %w", p.userID, p.traceID, err)
}

if response.Msg.Status == pbworker.BorrowWorkerResponse_borrowed {
if response.Status == pbworker.BorrowWorkerResponse_resource_exhausted {
p.logger.Info("worker pool is exhausted", zap.String("worker_key", response.WorkerKey), zap.String("status", response.Status.String()))
return nil, ErrorResourceExhausted
}

p.firstWorkerServed = true
worker := NewRemoteWorker(p.clientFactory, response.Msg.WorkerKey, p.workerKeepAliveDelay, p.logger)
worker := NewRemoteWorker(p.clientFactory, response.WorkerKey, p.workerKeepAliveDelay, p.logger)
return worker, nil
}

func (p *GlobalWorkerPool) Return(ctx context.Context, worker Worker) {
key := worker.ID()
_, err := p.remoteWorkerPool.ReturnWorker(ctx, &connect.Request[pbworker.ReturnWorkerRequest]{
Msg: &pbworker.ReturnWorkerRequest{
_, err := p.remoteWorkerPoolClient.ReturnWorker(ctx,
&pbworker.ReturnWorkerRequest{
WorkerKey: key,
},
})
})

if err != nil {
p.logger.Error("returning worker", zap.Error(err))
}
p.logger.Info("returning worker", zap.String("worker_key", key))
}
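
Besides dropping the connect.Request wrappers and the response .Msg indirection in favour of the plain pbworker gRPC client, this file fixes an inverted ramp-up condition: the pool now refuses a second worker only while it is still inside the four-second ramp-up window, and Borrow keys its exhaustion path off the resource_exhausted status. A self-contained sketch of the corrected gate (the four-second window comes from the diff; everything else is simplified):

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var errResourceExhausted = errors.New("resource exhausted")

type pool struct {
	startedAt         time.Time
	firstWorkerServed bool
}

// borrowGate mirrors the fixed check: during the first four seconds the pool
// hands out exactly one worker; once the window has passed, it stops limiting.
func (p *pool) borrowGate() error {
	rampingUp := time.Since(p.startedAt) < 4*time.Second
	if rampingUp && p.firstWorkerServed {
		return errResourceExhausted
	}
	p.firstWorkerServed = true
	return nil
}

func main() {
	p := &pool{startedAt: time.Now()}
	fmt.Println("first borrow during ramp-up:", p.borrowGate())  // <nil>
	fmt.Println("second borrow during ramp-up:", p.borrowGate()) // resource exhausted

	p.startedAt = time.Now().Add(-5 * time.Second) // pretend the window has passed
	fmt.Println("borrow after ramp-up:", p.borrowGate()) // <nil>
}
```
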
1 change: 1 addition & 0 deletions orchestrator/work/worker.go
@@ -70,6 +70,7 @@ type RemoteWorker struct {
}

func NewRemoteWorker(clientFactory client.InternalClientFactory, id string, keepAliveDelay time.Duration, logger *zap.Logger) *RemoteWorker {
logger = logger.Named("remote-worker")
return &RemoteWorker{
clientFactory: clientFactory,
tracer: otel.GetTracerProvider().Tracer("worker"),
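
The only change here is a named child logger, matching the scheduler and worker-pool names added above, so each log line can be attributed to the component that emitted it. A quick illustration of zap's Named (the field values are hypothetical):

```go
package main

import "go.uber.org/zap"

func main() {
	base, _ := zap.NewDevelopment()
	defer base.Sync()

	// Named returns a child logger whose entries carry the dotted name, so
	// remote-worker lines are easy to separate from scheduler or pool output.
	worker := base.Named("remote-worker")
	worker.Info("keep alive sent", zap.String("worker_key", "w-123")) // hypothetical key
}
```
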
21 changes: 9 additions & 12 deletions service/tier2.go
@@ -36,7 +36,6 @@ import (
"github.com/streamingfast/substreams/storage/store"
"github.com/streamingfast/substreams/wasm"
pbworker "github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1"
"github.com/streamingfast/worker-pool-protocol/pb/sf/worker/v1/pbworkerconnect"
"go.opentelemetry.io/otel/attribute"
ttrace "go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
@@ -81,22 +80,19 @@ type Tier2Service struct {
blockExecutionTimeout time.Duration

tier2RequestParameters *reqctx.Tier2RequestParameters
remoteWorkerClient pbworkerconnect.WorkerPoolClient
workerPoolFactory work.WorkerPoolFactory
remoteWorkerClient pbworker.WorkerPoolClient
}

const protoPkfPrefix = "type.googleapis.com/"

func NewTier2(
remoteWorkerClient pbworkerconnect.WorkerPoolClient,
workerPoolFactory work.WorkerPoolFactory,
remoteWorkerClient pbworker.WorkerPoolClient,
logger *zap.Logger,
opts ...Option,
) (*Tier2Service, error) {

s := &Tier2Service{
remoteWorkerClient: remoteWorkerClient,
workerPoolFactory: workerPoolFactory,
tracer: tracing.GetTracer(),
logger: logger,
blockExecutionTimeout: 3 * time.Minute,
@@ -348,7 +344,7 @@ func (s *Tier2Service) processRange(ctx context.Context, request *pbssinternal.P
wasmRegistry,
execOutputCacheEngine,
request.SegmentSize,
s.workerPoolFactory,
nil,
respFunc,
s.blockExecutionTimeout,
opts...,
@@ -456,8 +452,9 @@ excludable:
}

done := make(chan struct{})
if s.remoteWorkerClient != nil && reflect.ValueOf(s.remoteWorkerClient).IsNil() {
if s.remoteWorkerClient != nil && !reflect.ValueOf(s.remoteWorkerClient).IsNil() {
workerID, keepAliveDelay, err := work.IncomingParameters(ctx)
s.logger.Info("got remote worker client, setting up keep alive", zap.String("worker_id", workerID), zap.Duration("keep_alive_delay", keepAliveDelay))
if err != nil {
return fmt.Errorf("getting incoming parameters: %w", err)
}
@@ -466,11 +463,11 @@
case <-ctx.Done():
return
case <-time.After(keepAliveDelay):
_, err := s.remoteWorkerClient.KeepAlive(ctx, &connect.Request[pbworker.KeepAliveRequest]{
Msg: &pbworker.KeepAliveRequest{
s.logger.Info("keep alive timer expired, calling keep alive")
_, err := s.remoteWorkerClient.KeepAlive(ctx,
&pbworker.KeepAliveRequest{
WorkerKey: workerID,
},
})
})
if err != nil {
s.logger.Error("failed to call keep alive", zap.String("worker_id", workerID), zap.Error(err))
}
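
The inverted reflect check was the functional bug in this file: an interface that wraps a typed nil pointer is not == nil, so the keep-alive goroutine previously started only when the client was a typed nil and never for a real one. A standalone illustration of the pitfall and of the corrected guard:

```go
package main

import (
	"fmt"
	"reflect"
)

// WorkerPoolClient stands in for the real generated interface (assumption).
type WorkerPoolClient interface{ KeepAlive() }

type grpcClient struct{}

func (c *grpcClient) KeepAlive() {}

// hasClient is the corrected guard: the interface must be non-nil AND the
// concrete value stored inside it must not be a nil pointer.
func hasClient(c WorkerPoolClient) bool {
	return c != nil && !reflect.ValueOf(c).IsNil()
}

func main() {
	var typedNil *grpcClient               // a nil pointer...
	var client WorkerPoolClient = typedNil // ...wrapped in a non-nil interface

	fmt.Println(client == nil)            // false: the interface itself is not nil
	fmt.Println(hasClient(client))        // false: but it holds a nil pointer
	fmt.Println(hasClient(&grpcClient{})) // true: real client, keep-alive should run
}
```

Apart from this guard, the keep-alive loop is unchanged except for switching to the plain gRPC KeepAlive call and the extra log lines.
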
