diff --git a/vms/platformvm/state/staker.go b/vms/platformvm/state/staker.go index 37bc512e36cb..2488e4aff79e 100644 --- a/vms/platformvm/state/staker.go +++ b/vms/platformvm/state/staker.go @@ -83,7 +83,7 @@ func (s *Staker) Less(than *Staker) bool { return bytes.Compare(s.TxID[:], than.TxID[:]) == -1 } -func NewCurrentStaker(txID ids.ID, staker txs.Staker, potentialReward uint64) (*Staker, error) { +func NewCurrentStaker(txID ids.ID, staker txs.ScheduledStaker, potentialReward uint64) (*Staker, error) { publicKey, _, err := staker.PublicKey() if err != nil { return nil, err @@ -103,7 +103,7 @@ func NewCurrentStaker(txID ids.ID, staker txs.Staker, potentialReward uint64) (* }, nil } -func NewPendingStaker(txID ids.ID, staker txs.Staker) (*Staker, error) { +func NewPendingStaker(txID ids.ID, staker txs.ScheduledStaker) (*Staker, error) { publicKey, _, err := staker.PublicKey() if err != nil { return nil, err diff --git a/vms/platformvm/state/staker_test.go b/vms/platformvm/state/staker_test.go index a6faa4ab704f..9482e33b793a 100644 --- a/vms/platformvm/state/staker_test.go +++ b/vms/platformvm/state/staker_test.go @@ -148,7 +148,7 @@ func TestNewCurrentStaker(t *testing.T) { potentialReward := uint64(54321) currentPriority := txs.SubnetPermissionedValidatorCurrentPriority - stakerTx := txs.NewMockStaker(ctrl) + stakerTx := txs.NewMockScheduledStaker(ctrl) stakerTx.EXPECT().NodeID().Return(nodeID) stakerTx.EXPECT().PublicKey().Return(publicKey, true, nil) stakerTx.EXPECT().SubnetID().Return(subnetID) @@ -192,7 +192,7 @@ func TestNewPendingStaker(t *testing.T) { endTime := time.Now() pendingPriority := txs.SubnetPermissionedValidatorPendingPriority - stakerTx := txs.NewMockStaker(ctrl) + stakerTx := txs.NewMockScheduledStaker(ctrl) stakerTx.EXPECT().NodeID().Return(nodeID) stakerTx.EXPECT().PublicKey().Return(publicKey, true, nil) stakerTx.EXPECT().SubnetID().Return(subnetID) diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 82ff0107762a..36a49428020c 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -1320,9 +1320,13 @@ func (s *state) syncGenesis(genesisBlk block.Block, genesis *genesis.Genesis) er // Persist primary network validator set at genesis for _, vdrTx := range genesis.Validators { - validatorTx, ok := vdrTx.Unsigned.(txs.ValidatorTx) + // We expect genesis validator txs to be either AddValidatorTx or + // AddPermissionlessValidatorTx. + // + // TODO: Enforce stricter type check + validatorTx, ok := vdrTx.Unsigned.(txs.ScheduledStaker) if !ok { - return fmt.Errorf("expected tx type txs.ValidatorTx but got %T", vdrTx.Unsigned) + return fmt.Errorf("expected a scheduled staker but got %T", vdrTx.Unsigned) } stakeAmount := validatorTx.Weight() @@ -1451,7 +1455,7 @@ func (s *state) loadCurrentValidators() error { return fmt.Errorf("failed loading validator transaction txID %s, %w", txID, err) } - stakerTx, ok := tx.Unsigned.(txs.Staker) + stakerTx, ok := tx.Unsigned.(txs.ScheduledStaker) if !ok { return fmt.Errorf("expected tx type txs.Staker but got %T", tx.Unsigned) } @@ -1492,7 +1496,7 @@ func (s *state) loadCurrentValidators() error { return err } - stakerTx, ok := tx.Unsigned.(txs.Staker) + stakerTx, ok := tx.Unsigned.(txs.ScheduledStaker) if !ok { return fmt.Errorf("expected tx type txs.Staker but got %T", tx.Unsigned) } @@ -1539,7 +1543,7 @@ func (s *state) loadCurrentValidators() error { return err } - stakerTx, ok := tx.Unsigned.(txs.Staker) + stakerTx, ok := tx.Unsigned.(txs.ScheduledStaker) if !ok { return fmt.Errorf("expected tx type txs.Staker but got %T", tx.Unsigned) } @@ -1596,7 +1600,7 @@ func (s *state) loadPendingValidators() error { return err } - stakerTx, ok := tx.Unsigned.(txs.Staker) + stakerTx, ok := tx.Unsigned.(txs.ScheduledStaker) if !ok { return fmt.Errorf("expected tx type txs.Staker but got %T", tx.Unsigned) } @@ -1631,7 +1635,7 @@ func (s *state) loadPendingValidators() error { return err } - stakerTx, ok := tx.Unsigned.(txs.Staker) + stakerTx, ok := tx.Unsigned.(txs.ScheduledStaker) if !ok { return fmt.Errorf("expected tx type txs.Staker but got %T", tx.Unsigned) } diff --git a/vms/platformvm/txs/add_delegator_tx.go b/vms/platformvm/txs/add_delegator_tx.go index 4f6fbe395b02..af328cc693bf 100644 --- a/vms/platformvm/txs/add_delegator_tx.go +++ b/vms/platformvm/txs/add_delegator_tx.go @@ -19,7 +19,8 @@ import ( ) var ( - _ DelegatorTx = (*AddDelegatorTx)(nil) + _ DelegatorTx = (*AddDelegatorTx)(nil) + _ ScheduledStaker = (*AddDelegatorTx)(nil) errDelegatorWeightMismatch = errors.New("delegator weight is not equal to total stake weight") errStakeMustBeAVAX = errors.New("stake must be AVAX") diff --git a/vms/platformvm/txs/add_permissionless_delegator_tx.go b/vms/platformvm/txs/add_permissionless_delegator_tx.go index 43db685d7629..346a80dd61f7 100644 --- a/vms/platformvm/txs/add_permissionless_delegator_tx.go +++ b/vms/platformvm/txs/add_permissionless_delegator_tx.go @@ -17,7 +17,10 @@ import ( "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) -var _ DelegatorTx = (*AddPermissionlessDelegatorTx)(nil) +var ( + _ DelegatorTx = (*AddPermissionlessDelegatorTx)(nil) + _ ScheduledStaker = (*AddPermissionlessDelegatorTx)(nil) +) // AddPermissionlessDelegatorTx is an unsigned addPermissionlessDelegatorTx type AddPermissionlessDelegatorTx struct { diff --git a/vms/platformvm/txs/add_permissionless_validator_tx.go b/vms/platformvm/txs/add_permissionless_validator_tx.go index 8f313ae000b9..34b13129ade8 100644 --- a/vms/platformvm/txs/add_permissionless_validator_tx.go +++ b/vms/platformvm/txs/add_permissionless_validator_tx.go @@ -21,7 +21,8 @@ import ( ) var ( - _ ValidatorTx = (*AddPermissionlessValidatorTx)(nil) + _ ValidatorTx = (*AddPermissionlessValidatorTx)(nil) + _ ScheduledStaker = (*AddPermissionlessDelegatorTx)(nil) errEmptyNodeID = errors.New("validator nodeID cannot be empty") errNoStake = errors.New("no stake") diff --git a/vms/platformvm/txs/add_subnet_validator_tx.go b/vms/platformvm/txs/add_subnet_validator_tx.go index 0ac3474e1bd6..53fd43562c02 100644 --- a/vms/platformvm/txs/add_subnet_validator_tx.go +++ b/vms/platformvm/txs/add_subnet_validator_tx.go @@ -14,7 +14,8 @@ import ( ) var ( - _ StakerTx = (*AddSubnetValidatorTx)(nil) + _ StakerTx = (*AddSubnetValidatorTx)(nil) + _ ScheduledStaker = (*AddSubnetValidatorTx)(nil) errAddPrimaryNetworkValidator = errors.New("can't add primary network validator with AddSubnetValidatorTx") ) diff --git a/vms/platformvm/txs/add_validator_tx.go b/vms/platformvm/txs/add_validator_tx.go index be6a93c2e42b..a0e82dba7c77 100644 --- a/vms/platformvm/txs/add_validator_tx.go +++ b/vms/platformvm/txs/add_validator_tx.go @@ -19,7 +19,8 @@ import ( ) var ( - _ ValidatorTx = (*AddValidatorTx)(nil) + _ ValidatorTx = (*AddValidatorTx)(nil) + _ ScheduledStaker = (*AddValidatorTx)(nil) errTooManyShares = fmt.Errorf("a staker can only require at most %d shares from delegators", reward.PercentDenominator) ) diff --git a/vms/platformvm/txs/executor/standard_tx_executor.go b/vms/platformvm/txs/executor/standard_tx_executor.go index 083e4f4c75c7..b9930075875a 100644 --- a/vms/platformvm/txs/executor/standard_tx_executor.go +++ b/vms/platformvm/txs/executor/standard_tx_executor.go @@ -534,7 +534,7 @@ func (e *StandardTxExecutor) BaseTx(tx *txs.BaseTx) error { } // Creates the staker as defined in [stakerTx] and adds it to [e.State]. -func (e *StandardTxExecutor) putStaker(stakerTx txs.Staker) error { +func (e *StandardTxExecutor) putStaker(stakerTx txs.ScheduledStaker) error { txID := e.Tx.ID() staker, err := state.NewPendingStaker(txID, stakerTx) if err != nil { diff --git a/vms/platformvm/txs/mempool/mempool.go b/vms/platformvm/txs/mempool/mempool.go index ce0d6a96f071..bd219910bd9f 100644 --- a/vms/platformvm/txs/mempool/mempool.go +++ b/vms/platformvm/txs/mempool/mempool.go @@ -280,7 +280,7 @@ func (m *mempool) DropExpiredStakerTxs(minStartTime time.Time) []ids.ID { txIter := m.unissuedTxs.NewIterator() for txIter.Next() { tx := txIter.Value() - stakerTx, ok := tx.Unsigned.(txs.Staker) + stakerTx, ok := tx.Unsigned.(txs.ScheduledStaker) if !ok { continue } diff --git a/vms/platformvm/txs/mock_scheduled_staker.go b/vms/platformvm/txs/mock_scheduled_staker.go new file mode 100644 index 000000000000..ce1a22b4eda0 --- /dev/null +++ b/vms/platformvm/txs/mock_scheduled_staker.go @@ -0,0 +1,151 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ava-labs/avalanchego/vms/platformvm/txs (interfaces: ScheduledStaker) + +// Package txs is a generated GoMock package. +package txs + +import ( + reflect "reflect" + time "time" + + ids "github.com/ava-labs/avalanchego/ids" + bls "github.com/ava-labs/avalanchego/utils/crypto/bls" + gomock "go.uber.org/mock/gomock" +) + +// MockScheduledStaker is a mock of ScheduledStaker interface. +type MockScheduledStaker struct { + ctrl *gomock.Controller + recorder *MockScheduledStakerMockRecorder +} + +// MockScheduledStakerMockRecorder is the mock recorder for MockScheduledStaker. +type MockScheduledStakerMockRecorder struct { + mock *MockScheduledStaker +} + +// NewMockScheduledStaker creates a new mock instance. +func NewMockScheduledStaker(ctrl *gomock.Controller) *MockScheduledStaker { + mock := &MockScheduledStaker{ctrl: ctrl} + mock.recorder = &MockScheduledStakerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockScheduledStaker) EXPECT() *MockScheduledStakerMockRecorder { + return m.recorder +} + +// CurrentPriority mocks base method. +func (m *MockScheduledStaker) CurrentPriority() Priority { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CurrentPriority") + ret0, _ := ret[0].(Priority) + return ret0 +} + +// CurrentPriority indicates an expected call of CurrentPriority. +func (mr *MockScheduledStakerMockRecorder) CurrentPriority() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CurrentPriority", reflect.TypeOf((*MockScheduledStaker)(nil).CurrentPriority)) +} + +// EndTime mocks base method. +func (m *MockScheduledStaker) EndTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EndTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// EndTime indicates an expected call of EndTime. +func (mr *MockScheduledStakerMockRecorder) EndTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EndTime", reflect.TypeOf((*MockScheduledStaker)(nil).EndTime)) +} + +// NodeID mocks base method. +func (m *MockScheduledStaker) NodeID() ids.NodeID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NodeID") + ret0, _ := ret[0].(ids.NodeID) + return ret0 +} + +// NodeID indicates an expected call of NodeID. +func (mr *MockScheduledStakerMockRecorder) NodeID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeID", reflect.TypeOf((*MockScheduledStaker)(nil).NodeID)) +} + +// PendingPriority mocks base method. +func (m *MockScheduledStaker) PendingPriority() Priority { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PendingPriority") + ret0, _ := ret[0].(Priority) + return ret0 +} + +// PendingPriority indicates an expected call of PendingPriority. +func (mr *MockScheduledStakerMockRecorder) PendingPriority() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PendingPriority", reflect.TypeOf((*MockScheduledStaker)(nil).PendingPriority)) +} + +// PublicKey mocks base method. +func (m *MockScheduledStaker) PublicKey() (*bls.PublicKey, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PublicKey") + ret0, _ := ret[0].(*bls.PublicKey) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// PublicKey indicates an expected call of PublicKey. +func (mr *MockScheduledStakerMockRecorder) PublicKey() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublicKey", reflect.TypeOf((*MockScheduledStaker)(nil).PublicKey)) +} + +// StartTime mocks base method. +func (m *MockScheduledStaker) StartTime() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartTime") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// StartTime indicates an expected call of StartTime. +func (mr *MockScheduledStakerMockRecorder) StartTime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartTime", reflect.TypeOf((*MockScheduledStaker)(nil).StartTime)) +} + +// SubnetID mocks base method. +func (m *MockScheduledStaker) SubnetID() ids.ID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubnetID") + ret0, _ := ret[0].(ids.ID) + return ret0 +} + +// SubnetID indicates an expected call of SubnetID. +func (mr *MockScheduledStakerMockRecorder) SubnetID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubnetID", reflect.TypeOf((*MockScheduledStaker)(nil).SubnetID)) +} + +// Weight mocks base method. +func (m *MockScheduledStaker) Weight() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Weight") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// Weight indicates an expected call of Weight. +func (mr *MockScheduledStakerMockRecorder) Weight() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockScheduledStaker)(nil).Weight)) +} diff --git a/vms/platformvm/txs/mock_staker.go b/vms/platformvm/txs/mock_staker.go index e01ca66cf9e3..f74c2534ca39 100644 --- a/vms/platformvm/txs/mock_staker.go +++ b/vms/platformvm/txs/mock_staker.go @@ -81,20 +81,6 @@ func (mr *MockStakerMockRecorder) NodeID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeID", reflect.TypeOf((*MockStaker)(nil).NodeID)) } -// PendingPriority mocks base method. -func (m *MockStaker) PendingPriority() Priority { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PendingPriority") - ret0, _ := ret[0].(Priority) - return ret0 -} - -// PendingPriority indicates an expected call of PendingPriority. -func (mr *MockStakerMockRecorder) PendingPriority() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PendingPriority", reflect.TypeOf((*MockStaker)(nil).PendingPriority)) -} - // PublicKey mocks base method. func (m *MockStaker) PublicKey() (*bls.PublicKey, bool, error) { m.ctrl.T.Helper() @@ -111,20 +97,6 @@ func (mr *MockStakerMockRecorder) PublicKey() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublicKey", reflect.TypeOf((*MockStaker)(nil).PublicKey)) } -// StartTime mocks base method. -func (m *MockStaker) StartTime() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartTime") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// StartTime indicates an expected call of StartTime. -func (mr *MockStakerMockRecorder) StartTime() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartTime", reflect.TypeOf((*MockStaker)(nil).StartTime)) -} - // SubnetID mocks base method. func (m *MockStaker) SubnetID() ids.ID { m.ctrl.T.Helper() diff --git a/vms/platformvm/txs/staker_tx.go b/vms/platformvm/txs/staker_tx.go index 049d3519375f..1cdcdcc337a8 100644 --- a/vms/platformvm/txs/staker_tx.go +++ b/vms/platformvm/txs/staker_tx.go @@ -48,9 +48,13 @@ type Staker interface { // PublicKey returns the BLS public key registered by this transaction. If // there was no key registered by this transaction, it will return false. PublicKey() (*bls.PublicKey, bool, error) - StartTime() time.Time EndTime() time.Time Weight() uint64 - PendingPriority() Priority CurrentPriority() Priority } + +type ScheduledStaker interface { + Staker + StartTime() time.Time + PendingPriority() Priority +} diff --git a/vms/platformvm/validator_set_property_test.go b/vms/platformvm/validator_set_property_test.go index 2ac0d4358d7d..189e220679ff 100644 --- a/vms/platformvm/validator_set_property_test.go +++ b/vms/platformvm/validator_set_property_test.go @@ -373,7 +373,6 @@ func addPrimaryValidatorWithoutBLSKey(vm *VM, data *validatorInputData) (*state. } func internalAddValidator(vm *VM, signedTx *txs.Tx) (*state.Staker, error) { - stakerTx := signedTx.Unsigned.(txs.StakerTx) if err := vm.Network.IssueTx(context.Background(), signedTx); err != nil { return nil, fmt.Errorf("could not add tx to mempool: %w", err) } @@ -393,6 +392,7 @@ func internalAddValidator(vm *VM, signedTx *txs.Tx) (*state.Staker, error) { } // move time ahead, promoting the validator to current + stakerTx := signedTx.Unsigned.(txs.ScheduledStaker) currentTime := stakerTx.StartTime() vm.clock.Set(currentTime) vm.state.SetTimestamp(currentTime)