hotfix(exporter): ensure thread-safe committee observer management #1883

Open · wants to merge 8 commits into base: stage
Changes from 6 commits
116 changes: 63 additions & 53 deletions operator/validator/controller.go
@@ -175,7 +175,7 @@ type controller struct {
historySyncBatchSize int
messageValidator validation.MessageValidator

// nonCommittees is a cache of initialized committeeObserver instances
// committeeObservers is a cache of initialized committeeObserver instances
committeesObservers *ttlcache.Cache[spectypes.MessageID, *committeeObserver]
committeesObserversMutex sync.Mutex
attesterRoots *ttlcache.Cache[phase0.Root, struct{}]
@@ -281,7 +281,7 @@ func NewController(logger *zap.Logger, options ControllerOptions) Controller {
messageValidator: options.MessageValidator,
}

// Start automatic expired item deletion in nonCommitteeValidators.
// Start automatic expired item deletion in committeeObserverValidators.
go ctrl.committeesObservers.Start()
// Delete old root and domain entries.
go ctrl.attesterRoots.Start()
@@ -365,57 +365,74 @@ func (c *controller) handleRouterMessages() {
}
}

var nonCommitteeValidatorTTLs = map[spectypes.RunnerRole]int{
spectypes.RoleCommittee: 64,
spectypes.RoleProposer: 4,
spectypes.RoleAggregator: 4,
//spectypes.BNRoleSyncCommittee: 4,
var committeeObserverValidatorTTLs = map[spectypes.RunnerRole]int{
spectypes.RoleCommittee: 64,
spectypes.RoleProposer: 4,
spectypes.RoleAggregator: 4,
Comment on lines -353 to +356
Contributor: Maybe clarify that's denominated in slots (and not seconds or something).

spectypes.RoleSyncCommitteeContribution: 4,
}
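A quick note on the units question above: these values are slot counts, multiplied by the beacon network's slot duration when the TTL is applied (see calculateObserverTTL below). Assuming a 12-second slot, RoleCommittee's 64 slots work out to 64 × 12 s = 768 s, roughly 12.8 minutes, and the 4-slot roles to 48 seconds.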

func (c *controller) handleWorkerMessages(msg network.DecodedSSVMessage) error {
var ncv *committeeObserver
ssvMsg := msg.(*queue.SSVMessage)

item := c.getNonCommitteeValidators(ssvMsg.GetID())
if item == nil {
committeeObserverOptions := validator.CommitteeObserverOptions{
Logger: c.logger,
NetworkConfig: c.networkConfig,
ValidatorStore: c.validatorStore,
Network: c.validatorOptions.Network,
Storage: c.validatorOptions.Storage,
FullNode: c.validatorOptions.FullNode,
Operator: c.validatorOptions.Operator,
OperatorSigner: c.validatorOptions.OperatorSigner,
NewDecidedHandler: c.validatorOptions.NewDecidedHandler,
AttesterRoots: c.attesterRoots,
SyncCommRoots: c.syncCommRoots,
DomainCache: c.domainCache,
}
ncv = &committeeObserver{
CommitteeObserver: validator.NewCommitteeObserver(convert.MessageID(ssvMsg.MsgID), committeeObserverOptions),
}
ttlSlots := nonCommitteeValidatorTTLs[ssvMsg.MsgID.GetRoleType()]
c.committeesObservers.Set(
ssvMsg.GetID(),
ncv,
time.Duration(ttlSlots)*c.beacon.GetBeaconNetwork().SlotDurationSec(),
)
} else {
ncv = item
}
if err := c.handleNonCommitteeMessages(ssvMsg, ncv); err != nil {
return err
observer := c.getCommitteeObserver(ssvMsg)

if err := c.handleCommitteeObserverMessage(ssvMsg, observer); err != nil {
return fmt.Errorf("failed to handle committee observer message: %w", err)
}

return nil
}

func (c *controller) handleNonCommitteeMessages(msg *queue.SSVMessage, ncv *committeeObserver) error {
func (c *controller) getCommitteeObserver(ssvMsg *queue.SSVMessage) *committeeObserver {
c.committeesObserversMutex.Lock()
defer c.committeesObserversMutex.Unlock()

if msg.MsgType == spectypes.SSVConsensusMsgType {
// Check if the observer already exists
existingObserver := c.committeesObservers.Get(ssvMsg.GetID())
Contributor: Isn't ttlcache already thread-safe? Why do we need the mutex here then?

Contributor (author): Without the mutex, there's a risk of race conditions where multiple goroutines might attempt to create and set the same observer simultaneously.

Contributor: This can also be solved by the cache's GetOrSet func, but that involves creating the observer each time, sometimes for nothing; though it would be GC'd quickly if unused, and there is no goroutine in the creation.

Contributor (author): You're right that GetOrSet avoids race conditions without a mutex, but it can lead to unnecessary observer creation, even if quickly garbage collected. The current mutex ensures observers are only created when needed, minimizing overhead for non-trivial operations. If the cost of unused creation is negligible, switching to GetOrSet could simplify the code. Let me know if we'd like to explore this direction further.
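For reference (not part of this diff): a minimal, self-contained sketch contrasting the two approaches discussed in this thread. It assumes the jellydator/ttlcache/v3 API the controller already uses (Get, Set, GetOrSet and its WithTTL option); the observer type, newObserver constructor, key, and TTL values are stand-ins.

package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/jellydator/ttlcache/v3"
)

type observer struct{ id string }

// newObserver stands in for validator.NewCommitteeObserver; in the real code its
// construction is the non-trivial work the mutex approach avoids repeating.
func newObserver(id string) *observer {
	return &observer{id: id}
}

var (
	cache = ttlcache.New[string, *observer]()
	mu    sync.Mutex
)

// getOrCreateLocked mirrors this PR: double-checked creation under a mutex,
// so the observer is only constructed when no cache entry exists yet.
func getOrCreateLocked(id string, ttl time.Duration) *observer {
	mu.Lock()
	defer mu.Unlock()

	if item := cache.Get(id); item != nil {
		return item.Value()
	}
	obs := newObserver(id)
	cache.Set(id, obs, ttl)
	return obs
}

// getOrCreateAtomic mirrors the reviewer's alternative: no extra mutex, but a
// candidate observer is constructed on every call, even when an entry already exists.
func getOrCreateAtomic(id string, ttl time.Duration) *observer {
	item, _ := cache.GetOrSet(id, newObserver(id), ttlcache.WithTTL[string, *observer](ttl))
	return item.Value()
}

func main() {
	go cache.Start() // background expiry, as the controller does for committeesObservers

	a := getOrCreateLocked("committee-1", 768*time.Second) // 64 slots * 12s
	b := getOrCreateAtomic("committee-1", 768*time.Second)
	fmt.Println(a == b) // true: both paths return the same cached instance
}

The tradeoff is the one stated above: the locked path pays a mutex but only constructs on a miss, while GetOrSet stays lock-free for callers at the cost of a throwaway construction on every hit.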

if existingObserver != nil {
return existingObserver.Value()
}

// Create a new committee observer if it doesn't exist
committeeObserverOptions := validator.CommitteeObserverOptions{
Logger: c.logger,
NetworkConfig: c.networkConfig,
ValidatorStore: c.validatorStore,
Network: c.validatorOptions.Network,
Storage: c.validatorOptions.Storage,
FullNode: c.validatorOptions.FullNode,
Operator: c.validatorOptions.Operator,
OperatorSigner: c.validatorOptions.OperatorSigner,
NewDecidedHandler: c.validatorOptions.NewDecidedHandler,
AttesterRoots: c.attesterRoots,
SyncCommRoots: c.syncCommRoots,
DomainCache: c.domainCache,
}
newObserver := &committeeObserver{
CommitteeObserver: validator.NewCommitteeObserver(convert.MessageID(ssvMsg.MsgID), committeeObserverOptions),
}

c.committeesObservers.Set(
ssvMsg.GetID(),
newObserver,
c.calculateObserverTTL(ssvMsg.MsgID.GetRoleType()),
)

return newObserver
}

func (c *controller) calculateObserverTTL(roleType spectypes.RunnerRole) time.Duration {
ttlSlots := committeeObserverValidatorTTLs[roleType]
return time.Duration(ttlSlots) * c.beacon.GetBeaconNetwork().SlotDurationSec()
}
Comment on lines +410 to +413
Contributor: I think it's probably better to "panic" than "default to 0", in case a new type is introduced and we forgot to add it to the committeeObserverValidatorTTLs map (so we can discover the issue early).
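A minimal sketch of the fail-fast variant suggested here (hypothetical, not part of this PR; assumes fmt is already imported in controller.go):

func (c *controller) calculateObserverTTL(roleType spectypes.RunnerRole) time.Duration {
	ttlSlots, ok := committeeObserverValidatorTTLs[roleType]
	if !ok {
		// Fail fast: a RunnerRole missing from the TTL map is a programming error,
		// better surfaced immediately than hidden behind a zero-duration TTL.
		panic(fmt.Sprintf("no committee observer TTL configured for role %v", roleType))
	}
	return time.Duration(ttlSlots) * c.beacon.GetBeaconNetwork().SlotDurationSec()
}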


func (c *controller) handleCommitteeObserverMessage(msg *queue.SSVMessage, observer *committeeObserver) error {
observer.Lock()
defer observer.Unlock()

switch msg.GetType() {
case spectypes.SSVConsensusMsgType:
// Process proposal messages for committee consensus only to get the roots
if msg.MsgID.GetRoleType() != spectypes.RoleCommittee {
return nil
@@ -426,24 +443,17 @@ func (c *controller) handleNonCommitteeMessages(msg *queue.SSVMessage, ncv *comm
return nil
}

return ncv.OnProposalMsg(msg)
} else if msg.MsgType == spectypes.SSVPartialSignatureMsgType {
return observer.OnProposalMsg(msg)
case spectypes.SSVPartialSignatureMsgType:
pSigMessages := &spectypes.PartialSignatureMessages{}
if err := pSigMessages.Decode(msg.SignedSSVMessage.SSVMessage.GetData()); err != nil {
return err
return fmt.Errorf("failed to decode partial signature messages: %w", err)
}

return ncv.ProcessMessage(msg)
}
return nil
}

func (c *controller) getNonCommitteeValidators(messageId spectypes.MessageID) *committeeObserver {
item := c.committeesObservers.Get(messageId)
if item != nil {
return item.Value()
return observer.ProcessMessage(msg)
default:
return nil
}
return nil
}

// StartValidators loads all persisted shares and setup the corresponding validators
74 changes: 37 additions & 37 deletions protocol/v2/ssv/validator/non_committee_validator.go
@@ -70,7 +70,7 @@ func NewCommitteeObserver(identifier convert.MessageID, opts CommitteeObserverOp
// TODO: does the specific operator matter?

ctrl := qbftcontroller.NewController(identifier[:], opts.Operator, config, opts.OperatorSigner, opts.FullNode)
ctrl.StoredInstances = make(qbftcontroller.InstanceContainer, 0, nonCommitteeInstanceContainerCapacity(opts.FullNode))
ctrl.StoredInstances = make(qbftcontroller.InstanceContainer, 0, committeeObserverInstanceContainerCapacity(opts.FullNode))

return &CommitteeObserver{
qbftController: ctrl,
Expand All @@ -86,10 +86,10 @@ func NewCommitteeObserver(identifier convert.MessageID, opts CommitteeObserverOp
}
}

func (ncv *CommitteeObserver) ProcessMessage(msg *queue.SSVMessage) error {
func (o *CommitteeObserver) ProcessMessage(msg *queue.SSVMessage) error {
role := msg.MsgID.GetRoleType()

logger := ncv.logger.With(fields.Role(role))
logger := o.logger.With(fields.Role(role))
if role == spectypes.RoleCommittee {
cid := spectypes.CommitteeID(msg.GetID().GetDutyExecutorID()[16:])
logger = logger.With(fields.CommitteeID(cid))
@@ -113,7 +113,7 @@ func (ncv *CommitteeObserver) ProcessMessage(msg *queue.SSVMessage) error {
return fmt.Errorf("got invalid message %w", err)
}

quorums, err := ncv.processMessage(partialSigMessages)
quorums, err := o.processMessage(partialSigMessages)
if err != nil {
return fmt.Errorf("could not process SignedPartialSignatureMessage %w", err)
}
@@ -128,25 +128,25 @@ func (ncv *CommitteeObserver) ProcessMessage(msg *queue.SSVMessage) error {
operatorIDs = append(operatorIDs, strconv.FormatUint(share, 10))
}

validator, exists := ncv.ValidatorStore.ValidatorByIndex(key.ValidatorIndex)
validator, exists := o.ValidatorStore.ValidatorByIndex(key.ValidatorIndex)
if !exists {
return fmt.Errorf("could not find share for validator with index %d", key.ValidatorIndex)
}

beaconRoles := ncv.getBeaconRoles(msg, key.Root)
beaconRoles := o.getBeaconRoles(msg, key.Root)
if len(beaconRoles) == 0 {
logger.Warn("no roles found for quorum root",
zap.Uint64("validator_index", uint64(key.ValidatorIndex)),
fields.Validator(validator.ValidatorPubKey[:]),
zap.String("signers", strings.Join(operatorIDs, ", ")),
fields.BlockRoot(key.Root),
zap.String("qbft_ctrl_identifier", hex.EncodeToString(ncv.qbftController.Identifier)),
zap.String("qbft_ctrl_identifier", hex.EncodeToString(o.qbftController.Identifier)),
)
}

for _, beaconRole := range beaconRoles {
msgID := convert.NewMsgID(ncv.qbftController.GetConfig().GetSignatureDomainType(), validator.ValidatorPubKey[:], beaconRole)
roleStorage := ncv.Storage.Get(msgID.GetRoleType())
msgID := convert.NewMsgID(o.qbftController.GetConfig().GetSignatureDomainType(), validator.ValidatorPubKey[:], beaconRole)
roleStorage := o.Storage.Get(msgID.GetRoleType())
if roleStorage == nil {
return fmt.Errorf("role storage doesn't exist: %v", beaconRole)
}
@@ -169,8 +169,8 @@ func (ncv *CommitteeObserver) ProcessMessage(msg *queue.SSVMessage) error {
fields.BlockRoot(key.Root),
)

if ncv.newDecidedHandler != nil {
ncv.newDecidedHandler(qbftstorage.ParticipantsRangeEntry{
if o.newDecidedHandler != nil {
o.newDecidedHandler(qbftstorage.ParticipantsRangeEntry{
Slot: slot,
Signers: quorum,
Identifier: msgID,
@@ -182,10 +182,10 @@ func (ncv *CommitteeObserver) ProcessMessage(msg *queue.SSVMessage) error {
return nil
}

func (ncv *CommitteeObserver) getBeaconRoles(msg *queue.SSVMessage, root phase0.Root) []convert.RunnerRole {
func (o *CommitteeObserver) getBeaconRoles(msg *queue.SSVMessage, root phase0.Root) []convert.RunnerRole {
if msg.MsgID.GetRoleType() == spectypes.RoleCommittee {
attester := ncv.attesterRoots.Get(root)
syncCommittee := ncv.syncCommRoots.Get(root)
attester := o.attesterRoots.Get(root)
syncCommittee := o.syncCommRoots.Get(root)

switch {
case attester != nil && syncCommittee != nil:
@@ -201,8 +201,8 @@ func (ncv *CommitteeObserver) getBeaconRoles(msg *queue.SSVMessage, root phase0.
return []convert.RunnerRole{casts.RunnerRoleToConvertRole(msg.MsgID.GetRoleType())}
}

// nonCommitteeInstanceContainerCapacity returns the capacity of InstanceContainer for non-committee validators
func nonCommitteeInstanceContainerCapacity(fullNode bool) int {
// committeeObserverInstanceContainerCapacity returns the capacity of InstanceContainer for committee observer validators
func committeeObserverInstanceContainerCapacity(fullNode bool) int {
if fullNode {
// Helps full nodes reduce
return 2
@@ -215,23 +215,23 @@ type validatorIndexAndRoot struct {
Root phase0.Root
}

func (ncv *CommitteeObserver) processMessage(
func (o *CommitteeObserver) processMessage(
signedMsg *spectypes.PartialSignatureMessages,
) (map[validatorIndexAndRoot][]spectypes.OperatorID, error) {
quorums := make(map[validatorIndexAndRoot][]spectypes.OperatorID)

for _, msg := range signedMsg.Messages {
validator, exists := ncv.ValidatorStore.ValidatorByIndex(msg.ValidatorIndex)
validator, exists := o.ValidatorStore.ValidatorByIndex(msg.ValidatorIndex)
if !exists {
return nil, fmt.Errorf("could not find share for validator with index %d", msg.ValidatorIndex)
}
container, ok := ncv.postConsensusContainer[msg.ValidatorIndex]
container, ok := o.postConsensusContainer[msg.ValidatorIndex]
if !ok {
container = ssv.NewPartialSigContainer(validator.Quorum())
ncv.postConsensusContainer[msg.ValidatorIndex] = container
o.postConsensusContainer[msg.ValidatorIndex] = container
}
if container.HasSignature(msg.ValidatorIndex, msg.Signer, msg.SigningRoot) {
ncv.resolveDuplicateSignature(container, msg, validator)
o.resolveDuplicateSignature(container, msg, validator)
} else {
container.AddSignature(msg)
}
@@ -255,11 +255,11 @@ func (ncv *CommitteeObserver) processMessage(

// Stores the container's existing signature or the new one, depending on their validity. If both are invalid, remove the existing one
// copied from BaseRunner
func (ncv *CommitteeObserver) resolveDuplicateSignature(container *ssv.PartialSigContainer, msg *spectypes.PartialSignatureMessage, share *ssvtypes.SSVShare) {
func (o *CommitteeObserver) resolveDuplicateSignature(container *ssv.PartialSigContainer, msg *spectypes.PartialSignatureMessage, share *ssvtypes.SSVShare) {
// Check previous signature validity
previousSignature, err := container.GetSignature(msg.ValidatorIndex, msg.Signer, msg.SigningRoot)
if err == nil {
err = ncv.verifyBeaconPartialSignature(msg.Signer, previousSignature, msg.SigningRoot, share)
err = o.verifyBeaconPartialSignature(msg.Signer, previousSignature, msg.SigningRoot, share)
if err == nil {
// Keep the previous signature since it's correct
return
@@ -270,14 +270,14 @@ func (ncv *CommitteeObserver) resolveDuplicateSignature(container *ssv.PartialSi
container.Remove(msg.ValidatorIndex, msg.Signer, msg.SigningRoot)

// Hold the new signature, if correct
err = ncv.verifyBeaconPartialSignature(msg.Signer, msg.PartialSignature, msg.SigningRoot, share)
err = o.verifyBeaconPartialSignature(msg.Signer, msg.PartialSignature, msg.SigningRoot, share)
if err == nil {
container.AddSignature(msg)
}
}

// copied from BaseRunner
func (ncv *CommitteeObserver) verifyBeaconPartialSignature(signer uint64, signature spectypes.Signature, root phase0.Root, share *ssvtypes.SSVShare) error {
func (o *CommitteeObserver) verifyBeaconPartialSignature(signer uint64, signature spectypes.Signature, root phase0.Root, share *ssvtypes.SSVShare) error {
ssvtypes.MetricsSignaturesVerifications.WithLabelValues().Inc()

for _, n := range share.Committee {
@@ -301,33 +301,33 @@ func (ncv *CommitteeObserver) verifyBeaconPartialSignature(signer uint64, signat
return fmt.Errorf("unknown signer")
}

func (ncv *CommitteeObserver) OnProposalMsg(msg *queue.SSVMessage) error {
func (o *CommitteeObserver) OnProposalMsg(msg *queue.SSVMessage) error {
beaconVote := &spectypes.BeaconVote{}
if err := beaconVote.Decode(msg.SignedSSVMessage.FullData); err != nil {
ncv.logger.Debug("❗ failed to get beacon vote data", zap.Error(err))
o.logger.Debug("❗ failed to get beacon vote data", zap.Error(err))
return err
}

qbftMsg, ok := msg.Body.(*specqbft.Message)
if !ok {
ncv.logger.Fatal("unreachable: OnProposalMsg must be called only on qbft messages")
o.logger.Fatal("unreachable: OnProposalMsg must be called only on qbft messages")
}

epoch := ncv.beaconNetwork.EstimatedEpochAtSlot(phase0.Slot(qbftMsg.Height))
epoch := o.beaconNetwork.EstimatedEpochAtSlot(phase0.Slot(qbftMsg.Height))

if err := ncv.saveAttesterRoots(epoch, beaconVote, qbftMsg); err != nil {
if err := o.saveAttesterRoots(epoch, beaconVote, qbftMsg); err != nil {
return err
}

if err := ncv.saveSyncCommRoots(epoch, beaconVote); err != nil {
if err := o.saveSyncCommRoots(epoch, beaconVote); err != nil {
return err
}

return nil
}

func (ncv *CommitteeObserver) saveAttesterRoots(epoch phase0.Epoch, beaconVote *spectypes.BeaconVote, qbftMsg *specqbft.Message) error {
attesterDomain, err := ncv.domainCache.Get(epoch, spectypes.DomainAttester)
func (o *CommitteeObserver) saveAttesterRoots(epoch phase0.Epoch, beaconVote *spectypes.BeaconVote, qbftMsg *specqbft.Message) error {
attesterDomain, err := o.domainCache.Get(epoch, spectypes.DomainAttester)
if err != nil {
return err
}
@@ -339,14 +339,14 @@ func (ncv *CommitteeObserver) saveAttesterRoots(epoch phase0.Epoch, beaconVote *
return err
}

ncv.attesterRoots.Set(attesterRoot, struct{}{}, ttlcache.DefaultTTL)
o.attesterRoots.Set(attesterRoot, struct{}{}, ttlcache.DefaultTTL)
}

return nil
}

func (ncv *CommitteeObserver) saveSyncCommRoots(epoch phase0.Epoch, beaconVote *spectypes.BeaconVote) error {
syncCommDomain, err := ncv.domainCache.Get(epoch, spectypes.DomainSyncCommittee)
func (o *CommitteeObserver) saveSyncCommRoots(epoch phase0.Epoch, beaconVote *spectypes.BeaconVote) error {
syncCommDomain, err := o.domainCache.Get(epoch, spectypes.DomainSyncCommittee)
if err != nil {
return err
}
@@ -357,7 +357,7 @@ func (ncv *CommitteeObserver) saveSyncCommRoots(epoch phase0.Epoch, beaconVote *
return err
}

ncv.syncCommRoots.Set(syncCommitteeRoot, struct{}{}, ttlcache.DefaultTTL)
o.syncCommRoots.Set(syncCommitteeRoot, struct{}{}, ttlcache.DefaultTTL)

return nil
}