diff --git a/cmd/serve.go b/cmd/serve.go index a1f6431ef..9a2bcc23a 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -3,6 +3,7 @@ package cmd import ( "github.com/formancehq/go-libs/v2/logging" "github.com/formancehq/ledger/internal/api/common" + "github.com/formancehq/ledger/internal/leadership" systemstore "github.com/formancehq/ledger/internal/storage/system" "net/http" "net/http/pprof" @@ -108,6 +109,7 @@ func NewServeCommand() *cobra.Command { }), bus.NewFxModule(), ballast.Module(serveConfiguration.ballastSize), + leadership.NewFXModule(), api.Module(api.Config{ Version: Version, Debug: service.IsDebug(cmd), @@ -122,15 +124,15 @@ func NewServeCommand() *cobra.Command { }), fx.Decorate(func( params struct { - fx.In + fx.In - Handler chi.Router - HealthController *health.HealthController - Logger logging.Logger + Handler chi.Router + HealthController *health.HealthController + Logger logging.Logger - MeterProvider *metric.MeterProvider `optional:"true"` - Exporter *otlpmetrics.InMemoryExporter `optional:"true"` - }, + MeterProvider *metric.MeterProvider `optional:"true"` + Exporter *otlpmetrics.InMemoryExporter `optional:"true"` + }, ) chi.Router { return assembleFinalRouter( service.IsDebug(cmd), diff --git a/internal/README.md b/internal/README.md index bdd4b1dcd..616e988bf 100644 --- a/internal/README.md +++ b/internal/README.md @@ -147,7 +147,7 @@ var Zero = big.NewInt(0) ``` -## func [ComputeIdempotencyHash]() +## func ComputeIdempotencyHash ```go func ComputeIdempotencyHash(inputs any) string @@ -156,7 +156,7 @@ func ComputeIdempotencyHash(inputs any) string -## type [Account]() +## type Account @@ -175,7 +175,7 @@ type Account struct { ``` -### func \(Account\) [GetAddress]() +### func \(Account\) GetAddress ```go func (a Account) GetAddress() string @@ -184,7 +184,7 @@ func (a Account) GetAddress() string -## type [AccountMetadata]() +## type AccountMetadata @@ -193,7 +193,7 @@ type AccountMetadata map[string]metadata.Metadata ``` -## type [AccountsVolumes]() +## type AccountsVolumes @@ -209,7 +209,7 @@ type AccountsVolumes struct { ``` -## type [AggregatedVolumes]() +## type AggregatedVolumes @@ -220,7 +220,7 @@ type AggregatedVolumes struct { ``` -## type [BalancesByAssets]() +## type BalancesByAssets @@ -229,7 +229,7 @@ type BalancesByAssets map[string]*big.Int ``` -## type [BalancesByAssetsByAccounts]() +## type BalancesByAssetsByAccounts @@ -238,7 +238,7 @@ type BalancesByAssetsByAccounts map[string]BalancesByAssets ``` -## type [Configuration]() +## type Configuration @@ -251,7 +251,7 @@ type Configuration struct { ``` -### func [NewDefaultConfiguration]() +### func NewDefaultConfiguration ```go func NewDefaultConfiguration() Configuration @@ -260,7 +260,7 @@ func NewDefaultConfiguration() Configuration -### func \(\*Configuration\) [SetDefaults]() +### func \(\*Configuration\) SetDefaults ```go func (c *Configuration) SetDefaults() @@ -269,7 +269,7 @@ func (c *Configuration) SetDefaults() -### func \(\*Configuration\) [Validate]() +### func \(\*Configuration\) Validate ```go func (c *Configuration) Validate() error @@ -278,7 +278,7 @@ func (c *Configuration) Validate() error -## type [CreatedTransaction]() +## type CreatedTransaction @@ -290,7 +290,7 @@ type CreatedTransaction struct { ``` -### func \(CreatedTransaction\) [GetMemento]() +### func \(CreatedTransaction\) GetMemento ```go func (p CreatedTransaction) GetMemento() any @@ -299,7 +299,7 @@ func (p CreatedTransaction) GetMemento() any -### func \(CreatedTransaction\) [Type]() +### func \(CreatedTransaction\) Type ```go func (p CreatedTransaction) Type() LogType @@ -308,7 +308,7 @@ func (p CreatedTransaction) Type() LogType -## type [DeletedMetadata]() +## type DeletedMetadata @@ -321,7 +321,7 @@ type DeletedMetadata struct { ``` -### func \(DeletedMetadata\) [Type]() +### func \(DeletedMetadata\) Type ```go func (s DeletedMetadata) Type() LogType @@ -330,7 +330,7 @@ func (s DeletedMetadata) Type() LogType -### func \(\*DeletedMetadata\) [UnmarshalJSON]() +### func \(\*DeletedMetadata\) UnmarshalJSON ```go func (s *DeletedMetadata) UnmarshalJSON(data []byte) error @@ -339,7 +339,7 @@ func (s *DeletedMetadata) UnmarshalJSON(data []byte) error -## type [ErrInvalidBucketName]() +## type ErrInvalidBucketName @@ -350,7 +350,7 @@ type ErrInvalidBucketName struct { ``` -### func \(ErrInvalidBucketName\) [Error]() +### func \(ErrInvalidBucketName\) Error ```go func (e ErrInvalidBucketName) Error() string @@ -359,7 +359,7 @@ func (e ErrInvalidBucketName) Error() string -### func \(ErrInvalidBucketName\) [Is]() +### func \(ErrInvalidBucketName\) Is ```go func (e ErrInvalidBucketName) Is(err error) bool @@ -368,7 +368,7 @@ func (e ErrInvalidBucketName) Is(err error) bool -## type [ErrInvalidLedgerName]() +## type ErrInvalidLedgerName @@ -379,7 +379,7 @@ type ErrInvalidLedgerName struct { ``` -### func \(ErrInvalidLedgerName\) [Error]() +### func \(ErrInvalidLedgerName\) Error ```go func (e ErrInvalidLedgerName) Error() string @@ -388,7 +388,7 @@ func (e ErrInvalidLedgerName) Error() string -### func \(ErrInvalidLedgerName\) [Is]() +### func \(ErrInvalidLedgerName\) Is ```go func (e ErrInvalidLedgerName) Is(err error) bool @@ -397,7 +397,7 @@ func (e ErrInvalidLedgerName) Is(err error) bool -## type [Ledger]() +## type Ledger @@ -413,7 +413,7 @@ type Ledger struct { ``` -### func [MustNewWithDefault]() +### func MustNewWithDefault ```go func MustNewWithDefault(name string) Ledger @@ -422,7 +422,7 @@ func MustNewWithDefault(name string) Ledger -### func [New]() +### func New ```go func New(name string, configuration Configuration) (*Ledger, error) @@ -431,7 +431,7 @@ func New(name string, configuration Configuration) (*Ledger, error) -### func [NewWithDefaults]() +### func NewWithDefaults ```go func NewWithDefaults(name string) (*Ledger, error) @@ -440,7 +440,7 @@ func NewWithDefaults(name string) (*Ledger, error) -### func \(Ledger\) [HasFeature]() +### func \(Ledger\) HasFeature ```go func (l Ledger) HasFeature(feature, value string) bool @@ -449,7 +449,7 @@ func (l Ledger) HasFeature(feature, value string) bool -### func \(Ledger\) [WithMetadata]() +### func \(Ledger\) WithMetadata ```go func (l Ledger) WithMetadata(m metadata.Metadata) Ledger @@ -458,7 +458,7 @@ func (l Ledger) WithMetadata(m metadata.Metadata) Ledger -## type [Log]() +## type Log Log represents atomic actions made on the ledger. @@ -479,7 +479,7 @@ type Log struct { ``` -### func [NewLog]() +### func NewLog ```go func NewLog(payload LogPayload) Log @@ -488,7 +488,7 @@ func NewLog(payload LogPayload) Log -### func \(Log\) [ChainLog]() +### func \(Log\) ChainLog ```go func (l Log) ChainLog(previous *Log) Log @@ -497,7 +497,7 @@ func (l Log) ChainLog(previous *Log) Log -### func \(\*Log\) [ComputeHash]() +### func \(\*Log\) ComputeHash ```go func (l *Log) ComputeHash(previous *Log) @@ -506,7 +506,7 @@ func (l *Log) ComputeHash(previous *Log) -### func \(\*Log\) [UnmarshalJSON]() +### func \(\*Log\) UnmarshalJSON ```go func (l *Log) UnmarshalJSON(data []byte) error @@ -515,7 +515,7 @@ func (l *Log) UnmarshalJSON(data []byte) error -### func \(Log\) [WithIdempotencyKey]() +### func \(Log\) WithIdempotencyKey ```go func (l Log) WithIdempotencyKey(key string) Log @@ -524,7 +524,7 @@ func (l Log) WithIdempotencyKey(key string) Log -## type [LogPayload]() +## type LogPayload @@ -535,7 +535,7 @@ type LogPayload interface { ``` -### func [HydrateLog]() +### func HydrateLog ```go func HydrateLog(_type LogType, data []byte) (LogPayload, error) @@ -544,7 +544,7 @@ func HydrateLog(_type LogType, data []byte) (LogPayload, error) -## type [LogType]() +## type LogType @@ -564,7 +564,7 @@ const ( ``` -### func [LogTypeFromString]() +### func LogTypeFromString ```go func LogTypeFromString(logType string) LogType @@ -573,7 +573,7 @@ func LogTypeFromString(logType string) LogType -### func \(LogType\) [MarshalJSON]() +### func \(LogType\) MarshalJSON ```go func (lt LogType) MarshalJSON() ([]byte, error) @@ -582,7 +582,7 @@ func (lt LogType) MarshalJSON() ([]byte, error) -### func \(\*LogType\) [Scan]() +### func \(\*LogType\) Scan ```go func (lt *LogType) Scan(src interface{}) error @@ -591,7 +591,7 @@ func (lt *LogType) Scan(src interface{}) error -### func \(LogType\) [String]() +### func \(LogType\) String ```go func (lt LogType) String() string @@ -600,7 +600,7 @@ func (lt LogType) String() string -### func \(\*LogType\) [UnmarshalJSON]() +### func \(\*LogType\) UnmarshalJSON ```go func (lt *LogType) UnmarshalJSON(data []byte) error @@ -609,7 +609,7 @@ func (lt *LogType) UnmarshalJSON(data []byte) error -### func \(LogType\) [Value]() +### func \(LogType\) Value ```go func (lt LogType) Value() (driver.Value, error) @@ -618,7 +618,7 @@ func (lt LogType) Value() (driver.Value, error) -## type [Memento]() +## type Memento @@ -629,7 +629,7 @@ type Memento interface { ``` -## type [Move]() +## type Move @@ -650,7 +650,7 @@ type Move struct { ``` -## type [Moves]() +## type Moves @@ -659,7 +659,7 @@ type Moves []*Move ``` -### func \(Moves\) [ComputePostCommitEffectiveVolumes]() +### func \(Moves\) ComputePostCommitEffectiveVolumes ```go func (m Moves) ComputePostCommitEffectiveVolumes() PostCommitVolumes @@ -668,7 +668,7 @@ func (m Moves) ComputePostCommitEffectiveVolumes() PostCommitVolumes -## type [PostCommitVolumes]() +## type PostCommitVolumes @@ -677,7 +677,7 @@ type PostCommitVolumes map[string]VolumesByAssets ``` -### func \(PostCommitVolumes\) [AddInput]() +### func \(PostCommitVolumes\) AddInput ```go func (a PostCommitVolumes) AddInput(account, asset string, input *big.Int) @@ -686,7 +686,7 @@ func (a PostCommitVolumes) AddInput(account, asset string, input *big.Int) -### func \(PostCommitVolumes\) [AddOutput]() +### func \(PostCommitVolumes\) AddOutput ```go func (a PostCommitVolumes) AddOutput(account, asset string, output *big.Int) @@ -695,7 +695,7 @@ func (a PostCommitVolumes) AddOutput(account, asset string, output *big.Int) -### func \(PostCommitVolumes\) [Copy]() +### func \(PostCommitVolumes\) Copy ```go func (a PostCommitVolumes) Copy() PostCommitVolumes @@ -704,7 +704,7 @@ func (a PostCommitVolumes) Copy() PostCommitVolumes -### func \(PostCommitVolumes\) [Merge]() +### func \(PostCommitVolumes\) Merge ```go func (a PostCommitVolumes) Merge(volumes PostCommitVolumes) PostCommitVolumes @@ -713,7 +713,7 @@ func (a PostCommitVolumes) Merge(volumes PostCommitVolumes) PostCommitVolumes -## type [Posting]() +## type Posting @@ -727,7 +727,7 @@ type Posting struct { ``` -### func [NewPosting]() +### func NewPosting ```go func NewPosting(source string, destination string, asset string, amount *big.Int) Posting @@ -736,7 +736,7 @@ func NewPosting(source string, destination string, asset string, amount *big.Int -## type [Postings]() +## type Postings @@ -745,7 +745,7 @@ type Postings []Posting ``` -### func \(Postings\) [Reverse]() +### func \(Postings\) Reverse ```go func (p Postings) Reverse() Postings @@ -754,7 +754,7 @@ func (p Postings) Reverse() Postings -### func \(Postings\) [Validate]() +### func \(Postings\) Validate ```go func (p Postings) Validate() (int, error) @@ -763,7 +763,7 @@ func (p Postings) Validate() (int, error) -## type [RevertedTransaction]() +## type RevertedTransaction @@ -775,7 +775,7 @@ type RevertedTransaction struct { ``` -### func \(RevertedTransaction\) [GetMemento]() +### func \(RevertedTransaction\) GetMemento ```go func (r RevertedTransaction) GetMemento() any @@ -784,7 +784,7 @@ func (r RevertedTransaction) GetMemento() any -### func \(RevertedTransaction\) [Type]() +### func \(RevertedTransaction\) Type ```go func (r RevertedTransaction) Type() LogType @@ -793,7 +793,7 @@ func (r RevertedTransaction) Type() LogType -## type [SavedMetadata]() +## type SavedMetadata @@ -806,7 +806,7 @@ type SavedMetadata struct { ``` -### func \(SavedMetadata\) [Type]() +### func \(SavedMetadata\) Type ```go func (s SavedMetadata) Type() LogType @@ -815,7 +815,7 @@ func (s SavedMetadata) Type() LogType -### func \(\*SavedMetadata\) [UnmarshalJSON]() +### func \(\*SavedMetadata\) UnmarshalJSON ```go func (s *SavedMetadata) UnmarshalJSON(data []byte) error @@ -824,7 +824,7 @@ func (s *SavedMetadata) UnmarshalJSON(data []byte) error -## type [Transaction]() +## type Transaction @@ -845,7 +845,7 @@ type Transaction struct { ``` -### func [NewTransaction]() +### func NewTransaction ```go func NewTransaction() Transaction @@ -854,7 +854,7 @@ func NewTransaction() Transaction -### func \(Transaction\) [InvolvedAccounts]() +### func \(Transaction\) InvolvedAccounts ```go func (tx Transaction) InvolvedAccounts() []string @@ -863,7 +863,7 @@ func (tx Transaction) InvolvedAccounts() []string -### func \(Transaction\) [InvolvedDestinations]() +### func \(Transaction\) InvolvedDestinations ```go func (tx Transaction) InvolvedDestinations() map[string][]string @@ -872,7 +872,7 @@ func (tx Transaction) InvolvedDestinations() map[string][]string -### func \(Transaction\) [IsReverted]() +### func \(Transaction\) IsReverted ```go func (tx Transaction) IsReverted() bool @@ -881,7 +881,7 @@ func (tx Transaction) IsReverted() bool -### func \(Transaction\) [JSONSchemaExtend]() +### func \(Transaction\) JSONSchemaExtend ```go func (Transaction) JSONSchemaExtend(schema *jsonschema.Schema) @@ -890,7 +890,7 @@ func (Transaction) JSONSchemaExtend(schema *jsonschema.Schema) -### func \(Transaction\) [MarshalJSON]() +### func \(Transaction\) MarshalJSON ```go func (tx Transaction) MarshalJSON() ([]byte, error) @@ -899,7 +899,7 @@ func (tx Transaction) MarshalJSON() ([]byte, error) -### func \(Transaction\) [Reverse]() +### func \(Transaction\) Reverse ```go func (tx Transaction) Reverse() Transaction @@ -908,7 +908,7 @@ func (tx Transaction) Reverse() Transaction -### func \(Transaction\) [VolumeUpdates]() +### func \(Transaction\) VolumeUpdates ```go func (tx Transaction) VolumeUpdates() []AccountsVolumes @@ -917,7 +917,7 @@ func (tx Transaction) VolumeUpdates() []AccountsVolumes -### func \(Transaction\) [WithInsertedAt]() +### func \(Transaction\) WithInsertedAt ```go func (tx Transaction) WithInsertedAt(date time.Time) Transaction @@ -926,7 +926,7 @@ func (tx Transaction) WithInsertedAt(date time.Time) Transaction -### func \(Transaction\) [WithMetadata]() +### func \(Transaction\) WithMetadata ```go func (tx Transaction) WithMetadata(m metadata.Metadata) Transaction @@ -935,7 +935,7 @@ func (tx Transaction) WithMetadata(m metadata.Metadata) Transaction -### func \(Transaction\) [WithPostCommitEffectiveVolumes]() +### func \(Transaction\) WithPostCommitEffectiveVolumes ```go func (tx Transaction) WithPostCommitEffectiveVolumes(volumes PostCommitVolumes) Transaction @@ -944,7 +944,7 @@ func (tx Transaction) WithPostCommitEffectiveVolumes(volumes PostCommitVolumes) -### func \(Transaction\) [WithPostings]() +### func \(Transaction\) WithPostings ```go func (tx Transaction) WithPostings(postings ...Posting) Transaction @@ -953,7 +953,7 @@ func (tx Transaction) WithPostings(postings ...Posting) Transaction -### func \(Transaction\) [WithReference]() +### func \(Transaction\) WithReference ```go func (tx Transaction) WithReference(ref string) Transaction @@ -962,7 +962,7 @@ func (tx Transaction) WithReference(ref string) Transaction -### func \(Transaction\) [WithRevertedAt]() +### func \(Transaction\) WithRevertedAt ```go func (tx Transaction) WithRevertedAt(timestamp time.Time) Transaction @@ -971,7 +971,7 @@ func (tx Transaction) WithRevertedAt(timestamp time.Time) Transaction -### func \(Transaction\) [WithTimestamp]() +### func \(Transaction\) WithTimestamp ```go func (tx Transaction) WithTimestamp(ts time.Time) Transaction @@ -980,7 +980,7 @@ func (tx Transaction) WithTimestamp(ts time.Time) Transaction -## type [TransactionData]() +## type TransactionData @@ -995,7 +995,7 @@ type TransactionData struct { ``` -### func [NewTransactionData]() +### func NewTransactionData ```go func NewTransactionData() TransactionData @@ -1004,7 +1004,7 @@ func NewTransactionData() TransactionData -### func \(TransactionData\) [WithPostings]() +### func \(TransactionData\) WithPostings ```go func (data TransactionData) WithPostings(postings ...Posting) TransactionData @@ -1013,7 +1013,7 @@ func (data TransactionData) WithPostings(postings ...Posting) TransactionData -## type [Transactions]() +## type Transactions @@ -1024,7 +1024,7 @@ type Transactions struct { ``` -## type [Volumes]() +## type Volumes @@ -1036,7 +1036,7 @@ type Volumes struct { ``` -### func [NewEmptyVolumes]() +### func NewEmptyVolumes ```go func NewEmptyVolumes() Volumes @@ -1045,7 +1045,7 @@ func NewEmptyVolumes() Volumes -### func [NewVolumesInt64]() +### func NewVolumesInt64 ```go func NewVolumesInt64(input, output int64) Volumes @@ -1054,7 +1054,7 @@ func NewVolumesInt64(input, output int64) Volumes -### func \(Volumes\) [Balance]() +### func \(Volumes\) Balance ```go func (v Volumes) Balance() *big.Int @@ -1063,7 +1063,7 @@ func (v Volumes) Balance() *big.Int -### func \(Volumes\) [Copy]() +### func \(Volumes\) Copy ```go func (v Volumes) Copy() Volumes @@ -1072,7 +1072,7 @@ func (v Volumes) Copy() Volumes -### func \(Volumes\) [JSONSchemaExtend]() +### func \(Volumes\) JSONSchemaExtend ```go func (Volumes) JSONSchemaExtend(schema *jsonschema.Schema) @@ -1081,7 +1081,7 @@ func (Volumes) JSONSchemaExtend(schema *jsonschema.Schema) -### func \(Volumes\) [MarshalJSON]() +### func \(Volumes\) MarshalJSON ```go func (v Volumes) MarshalJSON() ([]byte, error) @@ -1090,7 +1090,7 @@ func (v Volumes) MarshalJSON() ([]byte, error) -### func \(\*Volumes\) [Scan]() +### func \(\*Volumes\) Scan ```go func (v *Volumes) Scan(src interface{}) error @@ -1099,7 +1099,7 @@ func (v *Volumes) Scan(src interface{}) error -### func \(Volumes\) [Value]() +### func \(Volumes\) Value ```go func (v Volumes) Value() (driver.Value, error) @@ -1108,7 +1108,7 @@ func (v Volumes) Value() (driver.Value, error) -## type [VolumesByAssets]() +## type VolumesByAssets @@ -1117,7 +1117,7 @@ type VolumesByAssets map[string]Volumes ``` -### func \(VolumesByAssets\) [Balances]() +### func \(VolumesByAssets\) Balances ```go func (v VolumesByAssets) Balances() BalancesByAssets @@ -1126,7 +1126,7 @@ func (v VolumesByAssets) Balances() BalancesByAssets -## type [VolumesWithBalance]() +## type VolumesWithBalance @@ -1139,7 +1139,7 @@ type VolumesWithBalance struct { ``` -## type [VolumesWithBalanceByAssetByAccount]() +## type VolumesWithBalanceByAssetByAccount @@ -1152,7 +1152,7 @@ type VolumesWithBalanceByAssetByAccount struct { ``` -## type [VolumesWithBalanceByAssets]() +## type VolumesWithBalanceByAssets diff --git a/internal/leadership/leadership.go b/internal/leadership/leadership.go new file mode 100644 index 000000000..2cef3459c --- /dev/null +++ b/internal/leadership/leadership.go @@ -0,0 +1,80 @@ +package leadership + +import ( + "context" + "errors" + "fmt" + "github.com/formancehq/go-libs/v2/logging" + "time" +) + +type Leadership struct { + locker Locker + changes *Signal + logger logging.Logger + retryPeriod time.Duration +} + +func (l *Leadership) acquire(ctx context.Context) error { + + acquired, release, err := l.locker.Take(ctx) + if err != nil { + return fmt.Errorf("error acquiring lock: %w", err) + } + + if acquired { + l.changes.Signal(true) + l.logger.Info("leadership acquired") + <-ctx.Done() + l.logger.Info("leadership lost") + release() + l.changes.Signal(false) + return ctx.Err() + } else { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(l.retryPeriod): + } + } + + return nil +} + +func (l *Leadership) Run(ctx context.Context) { + for { + if err := l.acquire(ctx); err != nil { + if errors.Is(err, context.Canceled) { + return + } + l.logger.Errorf("error acquiring leadership: %s", err) + } + } +} + +func (l *Leadership) GetSignal() *Signal { + return l.changes +} + +func NewLeadership(locker Locker, logger logging.Logger, options ...Option) *Leadership { + l := &Leadership{ + locker: locker, + logger: logger, + changes: NewSignal(), + retryPeriod: 2 * time.Second, + } + + for _, option := range options { + option(l) + } + + return l +} + +type Option func(leadership *Leadership) + +func WithRetryPeriod(duration time.Duration) Option { + return func(leadership *Leadership) { + leadership.retryPeriod = duration + } +} diff --git a/internal/leadership/leadership_test.go b/internal/leadership/leadership_test.go new file mode 100644 index 000000000..bd37f9ea4 --- /dev/null +++ b/internal/leadership/leadership_test.go @@ -0,0 +1,74 @@ +package leadership + +import ( + "context" + "github.com/formancehq/go-libs/v2/logging" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "testing" + "time" +) + +func TestLeaderShip(t *testing.T) { + + t.Parallel() + + ctx := logging.TestingContext() + ctrl := gomock.NewController(t) + + const count = 10 + selectedInstance := 0 + + type instance struct { + locker Locker + leadership *Leadership + ctx context.Context + cancel func() + } + + instances := make([]instance, count) + for i := range count { + m := NewMockLocker(ctrl) + m.EXPECT(). + Take(gomock.Any()). + AnyTimes(). + DoAndReturn(func(ctx context.Context) (bool, func(), error) { + return i == selectedInstance, func() {}, nil + }) + + l := NewLeadership(m, logging.Testing(), WithRetryPeriod(10*time.Millisecond)) + + ctx, cancel := context.WithCancel(ctx) + + go l.Run(ctx) + + instances[i] = instance{ + locker: m, + leadership: l, + ctx: ctx, + cancel: cancel, + } + } + + for _, nextLeader := range []int{0, 2, 4, 8} { + selectedInstance = nextLeader + + leadershipSignal, release := instances[nextLeader].leadership.GetSignal().Listen() + select { + case acquired := <-leadershipSignal: + require.True(t, acquired, "instance %d should be leader", nextLeader) + case <-time.After(100 * time.Millisecond): + t.Fatal("signal should have been received") + } + + instances[nextLeader].cancel() + + select { + case acquired := <-leadershipSignal: + require.False(t, acquired, "instance %d should have lost the leadership", nextLeader) + case <-time.After(100 * time.Millisecond): + t.Fatal("signal should have been received") + } + release() + } +} diff --git a/internal/leadership/locker.go b/internal/leadership/locker.go new file mode 100644 index 000000000..e8fadd2ed --- /dev/null +++ b/internal/leadership/locker.go @@ -0,0 +1,50 @@ +package leadership + +import ( + "context" + "fmt" + "github.com/uptrace/bun" +) + +const leadershipAdvisoryLockKey = 123456789 + +//go:generate mockgen -write_source_comment=false -write_package_comment=false -source locker.go -destination locker_generated_test.go -package leadership . Locker +type Locker interface { + Take(ctx context.Context) (bool, func(), error) +} + +type defaultLocker struct { + db *bun.DB +} + +func (p *defaultLocker) Take(ctx context.Context) (bool, func(), error) { + conn, err := p.db.Conn(ctx) + if err != nil { + return false, nil, fmt.Errorf("error opening new connection: %w", err) + } + + ret := conn.QueryRowContext(ctx, "select pg_try_advisory_lock(?)", leadershipAdvisoryLockKey) + if ret.Err() != nil { + _ = conn.Close() + return false, nil, fmt.Errorf("error acquiring lock: %w", ret.Err()) + } + + var acquired bool + if err := ret.Scan(&acquired); err != nil { + _ = conn.Close() + panic(err) + } + + if !acquired { + _ = conn.Close() + return false, nil, nil + } + + return true, func() { + _ = conn.Close() + }, nil +} + +func NewDefaultLocker(db *bun.DB) Locker { + return &defaultLocker{db: db} +} diff --git a/internal/leadership/locker_generated_test.go b/internal/leadership/locker_generated_test.go new file mode 100644 index 000000000..51e3a6ae8 --- /dev/null +++ b/internal/leadership/locker_generated_test.go @@ -0,0 +1,55 @@ +// Code generated by MockGen. DO NOT EDIT. +// +// Generated by this command: +// +// mockgen -write_source_comment=false -write_package_comment=false -source locker.go -destination locker_generated_test.go -package leadership . Locker +// + +package leadership + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockLocker is a mock of Locker interface. +type MockLocker struct { + ctrl *gomock.Controller + recorder *MockLockerMockRecorder + isgomock struct{} +} + +// MockLockerMockRecorder is the mock recorder for MockLocker. +type MockLockerMockRecorder struct { + mock *MockLocker +} + +// NewMockLocker creates a new mock instance. +func NewMockLocker(ctrl *gomock.Controller) *MockLocker { + mock := &MockLocker{ctrl: ctrl} + mock.recorder = &MockLockerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLocker) EXPECT() *MockLockerMockRecorder { + return m.recorder +} + +// Take mocks base method. +func (m *MockLocker) Take(ctx context.Context) (bool, func(), error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Take", ctx) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(func()) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Take indicates an expected call of Take. +func (mr *MockLockerMockRecorder) Take(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Take", reflect.TypeOf((*MockLocker)(nil).Take), ctx) +} diff --git a/internal/leadership/module.go b/internal/leadership/module.go new file mode 100644 index 000000000..104f45ccf --- /dev/null +++ b/internal/leadership/module.go @@ -0,0 +1,39 @@ +package leadership + +import ( + "context" + "go.uber.org/fx" +) + +func NewFXModule() fx.Option { + return fx.Options( + fx.Provide(NewLeadership), + fx.Provide(NewDefaultLocker), + fx.Invoke(func(lc fx.Lifecycle, runner *Leadership) { + var ( + cancel context.CancelFunc + stopped = make(chan struct{}) + ) + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + ctx, cancel = context.WithCancel(context.WithoutCancel(ctx)) + go func() { + defer close(stopped) + runner.Run(ctx) + }() + + return nil + }, + OnStop: func(ctx context.Context) error { + cancel() + select { + case <-stopped: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }, + }) + }), + ) +} diff --git a/internal/leadership/signal.go b/internal/leadership/signal.go new file mode 100644 index 000000000..3fe538436 --- /dev/null +++ b/internal/leadership/signal.go @@ -0,0 +1,88 @@ +package leadership + +import ( + "github.com/formancehq/go-libs/v2/pointer" + "sync" +) + +type listener struct { + channel chan bool +} + +type Signal struct { + mu *sync.Mutex + t *bool + + inner []listener + outer chan bool +} + +func (h *Signal) Actual() *bool { + h.mu.Lock() + defer h.mu.Unlock() + + if h.t == nil { + return nil + } + + return pointer.For(*h.t) +} + +func (h *Signal) Listen() (<-chan bool, func()) { + h.mu.Lock() + defer h.mu.Unlock() + + newChannel := make(chan bool, 1) + index := len(h.inner) + h.inner = append(h.inner, listener{ + channel: newChannel, + }) + if h.t != nil { + newChannel <- *h.t + } + + return newChannel, func() { + h.mu.Lock() + defer h.mu.Unlock() + + if index < len(h.inner)-1 { + h.inner = append(h.inner[:index], h.inner[index+1:]...) + } else { + h.inner = h.inner[:index] + } + } +} + +func (h *Signal) Signal(t bool) { + h.mu.Lock() + defer h.mu.Unlock() + + h.t = &t + + for _, inner := range h.inner { + inner.channel <- t + } +} + +func (h *Signal) Close() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, inner := range h.inner { + close(inner.channel) + } +} + +func (h *Signal) CountListeners() int { + h.mu.Lock() + defer h.mu.Unlock() + + return len(h.inner) +} + +func NewSignal() *Signal { + return &Signal{ + outer: make(chan bool), + mu: &sync.Mutex{}, + } +}