Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve address validator structure and make seal explicit #2022

Merged
merged 5 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions app/keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,13 @@ func initEvmKeeper(appCodec codec.Codec, keys map[string]*sdk.KVStoreKey, keeper
func initNexusKeeper(appCodec codec.Codec, keys map[string]*sdk.KVStoreKey, keepers *keeperCache) *nexusKeeper.Keeper {
// setting validator will finalize all by sealing it
// no more validators can be added
addressValidator := nexusTypes.NewAddressValidator().
addressValidator := nexusTypes.NewAddressValidators().
cgorenflo marked this conversation as resolved.
Show resolved Hide resolved
AddAddressValidator(evmTypes.ModuleName, evmKeeper.NewAddressValidator()).
AddAddressValidator(axelarnetTypes.ModuleName, axelarnetKeeper.NewAddressValidator(getKeeper[axelarnetKeeper.Keeper](keepers)))
addressValidator.Seal()

nexusK := nexusKeeper.NewKeeper(appCodec, keys[nexusTypes.StoreKey], keepers.getSubspace(nexusTypes.ModuleName))
nexusK.SetAddressValidator(addressValidator)
nexusK.SetAddressValidators(addressValidator)

return &nexusK
}
Expand Down
8 changes: 4 additions & 4 deletions x/nexus/keeper/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ func (k Keeper) GetRecipient(ctx sdk.Context, depositAddress exported.CrossChain

// ValidateAddress validates the given cross chain address
func (k Keeper) ValidateAddress(ctx sdk.Context, address exported.CrossChainAddress) error {
validator := k.getAddressValidator().GetAddressValidator(address.Chain.Module)
if validator == nil {
return fmt.Errorf("unknown module for chain %s", address.Chain.String())
validate, err := k.getAddressValidator(address.Chain.Module)
if err != nil {
return err
}

if err := validator(ctx, address); err != nil {
if err := validate(ctx, address); err != nil {
return err
}

Expand Down
8 changes: 5 additions & 3 deletions x/nexus/keeper/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ func setup() (sdk.Context, Keeper) {
},
}

router := types.NewAddressValidator()
router.AddAddressValidator(evmTypes.ModuleName, evmkeeper.NewAddressValidator()).
addressValidators := types.NewAddressValidators()
addressValidators.AddAddressValidator(evmTypes.ModuleName, evmkeeper.NewAddressValidator()).
AddAddressValidator(axelarnetTypes.ModuleName, axelarnetkeeper.NewAddressValidator(axelarnetK))
keeper.SetAddressValidator(router)

addressValidators.Seal()
keeper.SetAddressValidators(addressValidators)

return ctx, keeper
}
Expand Down
5 changes: 3 additions & 2 deletions x/nexus/keeper/grpc_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,14 @@ func TestKeeper_TransfersForChain(t *testing.T) {
funcs.MustNoErr(k.RegisterAsset(ctx, evm.Ethereum, exported.NewAsset(axelarnet.NativeAsset, false), utils.MaxUint, time.Hour))
funcs.MustNoErr(k.RegisterAsset(ctx, axelarnet.Axelarnet, exported.NewAsset(axelarnet.NativeAsset, true), utils.MaxUint, time.Hour))

nexusRouter := types.NewAddressValidator().
addressValidators := types.NewAddressValidators().
AddAddressValidator("evm", func(sdk.Context, exported.CrossChainAddress) error {
return nil
}).AddAddressValidator("axelarnet", func(sdk.Context, exported.CrossChainAddress) error {
return nil
})
k.SetAddressValidator(nexusRouter)
addressValidators.Seal()
k.SetAddressValidators(addressValidators)

}).
When("there are some pending transfers", func() {
Expand Down
30 changes: 15 additions & 15 deletions x/nexus/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
"fmt"
"github.com/axelarnetwork/axelar-core/x/nexus/exported"

"github.com/cosmos/cosmos-sdk/codec"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand Down Expand Up @@ -42,8 +43,8 @@ type Keeper struct {
cdc codec.BinaryCodec
params params.Subspace

addressValidator types.AddressValidator
messageRouter types.MessageRouter
addressValidators *types.AddressValidators
messageRouter types.MessageRouter
}

// NewKeeper returns a new nexus keeper
Expand All @@ -68,26 +69,25 @@ func (k Keeper) GetParams(ctx sdk.Context) types.Params {
return p
}

// SetAddressValidator sets the nexus address validator. It will panic if called more than once
func (k *Keeper) SetAddressValidator(validator types.AddressValidator) {
if k.addressValidator != nil {
panic("validator already set")
// SetAddressValidators sets the nexus address validator. It will panic if called more than once
func (k *Keeper) SetAddressValidators(validators *types.AddressValidators) {
if !validators.IsSealed() {
panic("address validator must be sealed")
}

k.addressValidator = validator
if k.addressValidators != nil {
panic("address validator already set")
}

// In order to avoid invalid or non-deterministic behavior, we seal the validator immediately
// to prevent additionals handlers from being registered after the keeper is initialized.
k.addressValidator.Seal()
k.addressValidators = validators
}

// getAddressValidator returns the nexus address validator. If not set, it returns a sealed empty validator
func (k Keeper) getAddressValidator() types.AddressValidator {
if k.addressValidator == nil {
k.SetAddressValidator(types.NewAddressValidator())
func (k Keeper) getAddressValidator(module string) (exported.AddressValidator, error) {
if k.addressValidators == nil {
cgorenflo marked this conversation as resolved.
Show resolved Hide resolved
k.SetAddressValidators(types.NewAddressValidators())
}

return k.addressValidator
return k.addressValidators.GetAddressValidator(module)
}

func (k *Keeper) SetMessageRouter(router types.MessageRouter) {
Expand Down
12 changes: 7 additions & 5 deletions x/nexus/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const maxAmount int64 = 100000000000

var k keeper.Keeper

func addressValidator() types.AddressValidator {
func addressValidators() *types.AddressValidators {
axelarnetK := &axelarnetmock.BaseKeeperMock{
GetCosmosChainByNameFunc: func(ctx sdk.Context, chain exported.ChainName) (axelarnetTypes.CosmosChain, bool) {
var prefix string
Expand All @@ -47,18 +47,20 @@ func addressValidator() types.AddressValidator {
},
}

router := types.NewAddressValidator()
router.AddAddressValidator(evmTypes.ModuleName, evmkeeper.NewAddressValidator()).
validators := types.NewAddressValidators()
validators.AddAddressValidator(evmTypes.ModuleName, evmkeeper.NewAddressValidator()).
AddAddressValidator(axelarnetTypes.ModuleName, axelarnetkeeper.NewAddressValidator(axelarnetK))

return router
validators.Seal()

return validators
}

func init() {
encCfg := app.MakeEncodingConfig()
subspace := params.NewSubspace(encCfg.Codec, encCfg.Amino, sdk.NewKVStoreKey("nexusKey"), sdk.NewKVStoreKey("tNexusKey"), "nexus")
k = keeper.NewKeeper(encCfg.Codec, sdk.NewKVStoreKey("nexus"), subspace)
k.SetAddressValidator(addressValidator())
k.SetAddressValidators(addressValidators())
}

func TestLinkAddress(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion x/nexus/keeper/transfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ func setup(cfg params.EncodingConfig) (nexusKeeper.Keeper, sdk.Context) {
ctx := sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger())

k.SetParams(ctx, types.DefaultParams())
k.SetAddressValidator(addressValidator())
k.SetAddressValidators(addressValidators())

// register asset in ChainState
for _, chain := range chains {
Expand Down
47 changes: 24 additions & 23 deletions x/nexus/types/address_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,36 @@ import (
"github.com/axelarnetwork/axelar-core/x/nexus/exported"
)

// AddressValidator implements a AddressValidator based on module name.
type AddressValidator interface {
AddAddressValidator(module string, validator exported.AddressValidator) AddressValidator
HasAddressValidator(module string) bool
GetAddressValidator(module string) exported.AddressValidator
Seal()
}

var _ AddressValidator = (*addressValidator)(nil)

type addressValidator struct {
// AddressValidators collects all registered address validators by module
type AddressValidators struct {
validators map[string]exported.AddressValidator
sealed bool
}

// NewAddressValidator creates a new AddressValidator interface instance
func NewAddressValidator() AddressValidator {
return &addressValidator{
// NewAddressValidators returns a new AddressValidators instance
func NewAddressValidators() *AddressValidators {
return &AddressValidators{
validators: make(map[string]exported.AddressValidator),
}
}

// Seal prevents additional validators from being added
func (r *addressValidator) Seal() {
func (r *AddressValidators) Seal() {
if r.sealed {
panic("cannot seal address validator (validator already sealed)")
}

r.sealed = true
}

// IsSealed returns true if the validator is sealed
func (r *AddressValidators) IsSealed() bool {
return r.sealed
}

// AddAddressValidator registers a validator for a given path
// panics if the validator is sealed, module is an empty string, or if the module has been registered already
func (r *addressValidator) AddAddressValidator(module string, validator exported.AddressValidator) AddressValidator {
func (r *AddressValidators) AddAddressValidator(module string, validator exported.AddressValidator) *AddressValidators {
if r.sealed {
panic("cannot add validator (validator sealed)")
}
Expand All @@ -53,15 +53,16 @@ func (r *addressValidator) AddAddressValidator(module string, validator exported
}

// HasAddressValidator returns true if a validator is registered for the given module
func (r *addressValidator) HasAddressValidator(module string) bool {
return r.validators[module] != nil
func (r *AddressValidators) HasAddressValidator(module string) bool {
cgorenflo marked this conversation as resolved.
Show resolved Hide resolved
_, err := r.GetAddressValidator(module)
return err == nil
}

// GetAddressValidator returns a validator for a given module
func (r *addressValidator) GetAddressValidator(module string) exported.AddressValidator {
if !r.HasAddressValidator(module) {
panic(fmt.Sprintf("validator for module \"%s\" not registered", module))
func (r *AddressValidators) GetAddressValidator(module string) (exported.AddressValidator, error) {
validate, ok := r.validators[module]
if !ok || validate == nil {
return nil, fmt.Errorf("validator for module \"%s\" not registered", module)
}

return r.validators[module]
return validate, nil
}