Skip to content

Commit

Permalink
fix fuzz test
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianElvis committed Sep 10, 2024
1 parent 86bba3f commit 4f2096f
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 31 deletions.
53 changes: 34 additions & 19 deletions x/zoneconcierge/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,37 @@ import (
"google.golang.org/grpc/status"
)

var _ types.QueryServer = Keeper{}
// Querier is used as Keeper will have duplicate methods if used directly, and gRPC names take precedence over keeper
type Querier struct {
Keeper
}

var _ types.QueryServer = Querier{}

const maxQueryChainsInfoLimit = 100

func (k Keeper) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) {
func validateRequest(req sdk.Msg) error {
if !types.EnableIntegration {
return types.ErrIntegrationDisabled
}
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
return status.Error(codes.InvalidArgument, "invalid request")
}
return nil
}

func (k Keeper) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) {
if err := validateRequest(req); err != nil {
return nil, err
}
ctx := sdk.UnwrapSDKContext(c)

return &types.QueryParamsResponse{Params: k.GetParams(ctx)}, nil
}

func (k Keeper) ChainList(c context.Context, req *types.QueryChainListRequest) (*types.QueryChainListResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

ctx := sdk.UnwrapSDKContext(c)
Expand All @@ -51,8 +66,8 @@ func (k Keeper) ChainList(c context.Context, req *types.QueryChainListRequest) (

// ChainsInfo returns the latest info for a given list of chains
func (k Keeper) ChainsInfo(c context.Context, req *types.QueryChainsInfoRequest) (*types.QueryChainsInfoResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

// return if no chain IDs are provided
Expand Down Expand Up @@ -87,8 +102,8 @@ func (k Keeper) ChainsInfo(c context.Context, req *types.QueryChainsInfoRequest)

// Header returns the header and fork headers at a given height
func (k Keeper) Header(c context.Context, req *types.QueryHeaderRequest) (*types.QueryHeaderResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

if len(req.ConsumerId) == 0 {
Expand All @@ -112,8 +127,8 @@ func (k Keeper) Header(c context.Context, req *types.QueryHeaderRequest) (*types

// EpochChainsInfo returns the latest info for list of chains in a given epoch
func (k Keeper) EpochChainsInfo(c context.Context, req *types.QueryEpochChainsInfoRequest) (*types.QueryEpochChainsInfoResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

// return if no chain IDs are provided
Expand Down Expand Up @@ -160,8 +175,8 @@ func (k Keeper) EpochChainsInfo(c context.Context, req *types.QueryEpochChainsIn

// ListHeaders returns all headers of a chain with given ID, with pagination support
func (k Keeper) ListHeaders(c context.Context, req *types.QueryListHeadersRequest) (*types.QueryListHeadersResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

if len(req.ConsumerId) == 0 {
Expand Down Expand Up @@ -192,8 +207,8 @@ func (k Keeper) ListHeaders(c context.Context, req *types.QueryListHeadersReques
// ListEpochHeaders returns all headers of a chain with given ID
// TODO: support pagination in this RPC
func (k Keeper) ListEpochHeaders(c context.Context, req *types.QueryListEpochHeadersRequest) (*types.QueryListEpochHeadersResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

if len(req.ConsumerId) == 0 {
Expand All @@ -215,8 +230,8 @@ func (k Keeper) ListEpochHeaders(c context.Context, req *types.QueryListEpochHea

// FinalizedChainsInfo returns the finalized info for a given list of chains
func (k Keeper) FinalizedChainsInfo(c context.Context, req *types.QueryFinalizedChainsInfoRequest) (*types.QueryFinalizedChainsInfoResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

// return if no chain IDs are provided
Expand Down Expand Up @@ -305,8 +320,8 @@ func (k Keeper) FinalizedChainsInfo(c context.Context, req *types.QueryFinalized
}

func (k Keeper) FinalizedChainInfoUntilHeight(c context.Context, req *types.QueryFinalizedChainInfoUntilHeightRequest) (*types.QueryFinalizedChainInfoUntilHeightResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
if err := validateRequest(req); err != nil {
return nil, err
}

if len(req.ConsumerId) == 0 {
Expand Down
12 changes: 8 additions & 4 deletions x/zoneconcierge/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/babylonlabs-io/babylon/testutil/datagen"
testhelper "github.com/babylonlabs-io/babylon/testutil/helper"
zckeeper "github.com/babylonlabs-io/babylon/x/zoneconcierge/keeper"
"github.com/babylonlabs-io/babylon/x/zoneconcierge/types"
zctypes "github.com/babylonlabs-io/babylon/x/zoneconcierge/types"
"github.com/cosmos/cosmos-sdk/baseapp"
ibctmtypes "github.com/cosmos/ibc-go/v8/modules/light-clients/07-tendermint"
Expand Down Expand Up @@ -93,6 +94,9 @@ func FuzzFeatureGate(f *testing.F) {
*/
// Get zone concierge query client
zcQueryHelper := baseapp.NewQueryServerTestHelper(ctx, helper.App.InterfaceRegistry())
querier := zckeeper.Querier{Keeper: zcKeeper}
types.RegisterQueryServer(zcQueryHelper, querier)

queryClient := zctypes.NewQueryClient(zcQueryHelper)

// Test GetParams query
Expand All @@ -102,23 +106,23 @@ func FuzzFeatureGate(f *testing.F) {
require.NoError(t, err, "Params query should work when EnableIntegration is true")
} else {
require.Error(t, err, "Params query should be blocked when EnableIntegration is false")
require.Contains(t, err.Error(), "handler not found for /babylon.zoneconcierge.v1.Query/Params")
require.ErrorIs(t, err, types.ErrIntegrationDisabled)
}

/*
Ensure msg server is feature gated
*/
msgClient := zctypes.NewMsgClient(zcQueryHelper)
msgSrvr := zckeeper.NewMsgServerImpl(zcKeeper)
msgReq := &zctypes.MsgUpdateParams{
Authority: helper.App.GovKeeper.GetGovernanceAccount(ctx).GetAddress().String(),
Params: zctypes.DefaultParams(),
}
_, err = msgClient.UpdateParams(ctx, msgReq)
_, err = msgSrvr.UpdateParams(ctx, msgReq)
if currentEnableIntegration {
require.NoError(t, err, "MsgUpdateParams should work when EnableIntegration is true")
} else {
require.Error(t, err, "MsgUpdateParams should be blocked when EnableIntegration is false")
require.Contains(t, err.Error(), "handler not found for /babylon.zoneconcierge.v1.Msg/UpdateParams")
require.ErrorIs(t, err, types.ErrIntegrationDisabled)
}

/*
Expand Down
3 changes: 3 additions & 0 deletions x/zoneconcierge/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ var _ types.MsgServer = msgServer{}

// UpdateParams updates the params
func (ms msgServer) UpdateParams(goCtx context.Context, req *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) {
if !types.EnableIntegration {
return nil, types.ErrIntegrationDisabled
}
if ms.authority != req.Authority {
return nil, errorsmod.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", ms.authority, req.Authority)
}
Expand Down
6 changes: 2 additions & 4 deletions x/zoneconcierge/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,8 @@ func (AppModule) QuerierRoute() string { return types.RouterKey }

// RegisterServices registers a gRPC query service to respond to the module-specific gRPC queries
func (am AppModule) RegisterServices(cfg module.Configurator) {
if types.EnableIntegration {
types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper))
types.RegisterQueryServer(cfg.QueryServer(), am.keeper)
}
types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper))
types.RegisterQueryServer(cfg.QueryServer(), am.keeper)
}

// RegisterInvariants registers the invariants of the module. If an invariant deviates from its predicted value, the InvariantRegistry triggers appropriate logic (most often the chain will be halted)
Expand Down
13 changes: 9 additions & 4 deletions x/zoneconcierge/types/btc_timestamp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ import (
"github.com/babylonlabs-io/babylon/x/zoneconcierge/types"
)

func init() {
types.EnableIntegration = true
}

func signBLSWithBitmap(blsSKs []bls12381.PrivateKey, bm bitmap.Bitmap, msg []byte) (bls12381.Signature, error) {
sigs := []bls12381.Signature{}
for i := 0; i < len(blsSKs); i++ {
Expand All @@ -38,6 +34,15 @@ func FuzzBTCTimestamp(f *testing.F) {
datagen.AddRandomSeedsToFuzzer(f, 10)

f.Fuzz(func(t *testing.T, seed int64) {
// Save the original value of EnableIntegration
originalEnableIntegration := types.EnableIntegration
// Restore the original value after the test
defer func() {
types.EnableIntegration = originalEnableIntegration
}()
// Set EnableIntegration to true
types.EnableIntegration = true

r := rand.New(rand.NewSource(seed))
// generate the validator set with 10 validators as genesis
genesisValSet, privSigner, err := datagen.GenesisValidatorSetWithPrivSigner(10)
Expand Down

0 comments on commit 4f2096f

Please sign in to comment.