diff --git a/x/zoneconcierge/keeper/grpc_query.go b/x/zoneconcierge/keeper/grpc_query.go index 5b0e7ee62..6a87ab455 100644 --- a/x/zoneconcierge/keeper/grpc_query.go +++ b/x/zoneconcierge/keeper/grpc_query.go @@ -11,13 +11,28 @@ import ( "google.golang.org/grpc/status" ) -var _ types.QueryServer = Keeper{} +// Querier is used as Keeper will have duplicate methods if used directly, and gRPC names take precedence over keeper +type Querier struct { + Keeper +} + +var _ types.QueryServer = Querier{} const maxQueryChainsInfoLimit = 100 -func (k Keeper) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) { +func validateRequest(req sdk.Msg) error { + if !types.EnableIntegration { + return types.ErrIntegrationDisabled + } if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + return status.Error(codes.InvalidArgument, "invalid request") + } + return nil +} + +func (k Keeper) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) { + if err := validateRequest(req); err != nil { + return nil, err } ctx := sdk.UnwrapSDKContext(c) @@ -25,8 +40,8 @@ func (k Keeper) Params(c context.Context, req *types.QueryParamsRequest) (*types } func (k Keeper) ChainList(c context.Context, req *types.QueryChainListRequest) (*types.QueryChainListResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } ctx := sdk.UnwrapSDKContext(c) @@ -51,8 +66,8 @@ func (k Keeper) ChainList(c context.Context, req *types.QueryChainListRequest) ( // ChainsInfo returns the latest info for a given list of chains func (k Keeper) ChainsInfo(c context.Context, req *types.QueryChainsInfoRequest) (*types.QueryChainsInfoResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } // return if no chain IDs are provided @@ -87,8 +102,8 @@ func (k Keeper) ChainsInfo(c context.Context, req *types.QueryChainsInfoRequest) // Header returns the header and fork headers at a given height func (k Keeper) Header(c context.Context, req *types.QueryHeaderRequest) (*types.QueryHeaderResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } if len(req.ConsumerId) == 0 { @@ -112,8 +127,8 @@ func (k Keeper) Header(c context.Context, req *types.QueryHeaderRequest) (*types // EpochChainsInfo returns the latest info for list of chains in a given epoch func (k Keeper) EpochChainsInfo(c context.Context, req *types.QueryEpochChainsInfoRequest) (*types.QueryEpochChainsInfoResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } // return if no chain IDs are provided @@ -160,8 +175,8 @@ func (k Keeper) EpochChainsInfo(c context.Context, req *types.QueryEpochChainsIn // ListHeaders returns all headers of a chain with given ID, with pagination support func (k Keeper) ListHeaders(c context.Context, req *types.QueryListHeadersRequest) (*types.QueryListHeadersResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } if len(req.ConsumerId) == 0 { @@ -192,8 +207,8 @@ func (k Keeper) ListHeaders(c context.Context, req *types.QueryListHeadersReques // ListEpochHeaders returns all headers of a chain with given ID // TODO: support pagination in this RPC func (k Keeper) ListEpochHeaders(c context.Context, req *types.QueryListEpochHeadersRequest) (*types.QueryListEpochHeadersResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } if len(req.ConsumerId) == 0 { @@ -215,8 +230,8 @@ func (k Keeper) ListEpochHeaders(c context.Context, req *types.QueryListEpochHea // FinalizedChainsInfo returns the finalized info for a given list of chains func (k Keeper) FinalizedChainsInfo(c context.Context, req *types.QueryFinalizedChainsInfoRequest) (*types.QueryFinalizedChainsInfoResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } // return if no chain IDs are provided @@ -305,8 +320,8 @@ func (k Keeper) FinalizedChainsInfo(c context.Context, req *types.QueryFinalized } func (k Keeper) FinalizedChainInfoUntilHeight(c context.Context, req *types.QueryFinalizedChainInfoUntilHeightRequest) (*types.QueryFinalizedChainInfoUntilHeightResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") + if err := validateRequest(req); err != nil { + return nil, err } if len(req.ConsumerId) == 0 { diff --git a/x/zoneconcierge/keeper/keeper_test.go b/x/zoneconcierge/keeper/keeper_test.go index b7cb8d906..583fb727c 100644 --- a/x/zoneconcierge/keeper/keeper_test.go +++ b/x/zoneconcierge/keeper/keeper_test.go @@ -8,6 +8,7 @@ import ( "github.com/babylonlabs-io/babylon/testutil/datagen" testhelper "github.com/babylonlabs-io/babylon/testutil/helper" zckeeper "github.com/babylonlabs-io/babylon/x/zoneconcierge/keeper" + "github.com/babylonlabs-io/babylon/x/zoneconcierge/types" zctypes "github.com/babylonlabs-io/babylon/x/zoneconcierge/types" "github.com/cosmos/cosmos-sdk/baseapp" ibctmtypes "github.com/cosmos/ibc-go/v8/modules/light-clients/07-tendermint" @@ -93,6 +94,9 @@ func FuzzFeatureGate(f *testing.F) { */ // Get zone concierge query client zcQueryHelper := baseapp.NewQueryServerTestHelper(ctx, helper.App.InterfaceRegistry()) + querier := zckeeper.Querier{Keeper: zcKeeper} + types.RegisterQueryServer(zcQueryHelper, querier) + queryClient := zctypes.NewQueryClient(zcQueryHelper) // Test GetParams query @@ -102,23 +106,23 @@ func FuzzFeatureGate(f *testing.F) { require.NoError(t, err, "Params query should work when EnableIntegration is true") } else { require.Error(t, err, "Params query should be blocked when EnableIntegration is false") - require.Contains(t, err.Error(), "handler not found for /babylon.zoneconcierge.v1.Query/Params") + require.ErrorIs(t, err, types.ErrIntegrationDisabled) } /* Ensure msg server is feature gated */ - msgClient := zctypes.NewMsgClient(zcQueryHelper) + msgSrvr := zckeeper.NewMsgServerImpl(zcKeeper) msgReq := &zctypes.MsgUpdateParams{ Authority: helper.App.GovKeeper.GetGovernanceAccount(ctx).GetAddress().String(), Params: zctypes.DefaultParams(), } - _, err = msgClient.UpdateParams(ctx, msgReq) + _, err = msgSrvr.UpdateParams(ctx, msgReq) if currentEnableIntegration { require.NoError(t, err, "MsgUpdateParams should work when EnableIntegration is true") } else { require.Error(t, err, "MsgUpdateParams should be blocked when EnableIntegration is false") - require.Contains(t, err.Error(), "handler not found for /babylon.zoneconcierge.v1.Msg/UpdateParams") + require.ErrorIs(t, err, types.ErrIntegrationDisabled) } /* diff --git a/x/zoneconcierge/keeper/msg_server.go b/x/zoneconcierge/keeper/msg_server.go index 26d5e9c96..1e5189eb6 100644 --- a/x/zoneconcierge/keeper/msg_server.go +++ b/x/zoneconcierge/keeper/msg_server.go @@ -23,6 +23,9 @@ var _ types.MsgServer = msgServer{} // UpdateParams updates the params func (ms msgServer) UpdateParams(goCtx context.Context, req *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) { + if !types.EnableIntegration { + return nil, types.ErrIntegrationDisabled + } if ms.authority != req.Authority { return nil, errorsmod.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", ms.authority, req.Authority) } diff --git a/x/zoneconcierge/module.go b/x/zoneconcierge/module.go index 417a16628..4f0ad4b10 100644 --- a/x/zoneconcierge/module.go +++ b/x/zoneconcierge/module.go @@ -120,10 +120,8 @@ func (AppModule) QuerierRoute() string { return types.RouterKey } // RegisterServices registers a gRPC query service to respond to the module-specific gRPC queries func (am AppModule) RegisterServices(cfg module.Configurator) { - if types.EnableIntegration { - types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper)) - types.RegisterQueryServer(cfg.QueryServer(), am.keeper) - } + types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper)) + types.RegisterQueryServer(cfg.QueryServer(), am.keeper) } // RegisterInvariants registers the invariants of the module. If an invariant deviates from its predicted value, the InvariantRegistry triggers appropriate logic (most often the chain will be halted) diff --git a/x/zoneconcierge/types/btc_timestamp_test.go b/x/zoneconcierge/types/btc_timestamp_test.go index c214caeda..3eabd062c 100644 --- a/x/zoneconcierge/types/btc_timestamp_test.go +++ b/x/zoneconcierge/types/btc_timestamp_test.go @@ -19,10 +19,6 @@ import ( "github.com/babylonlabs-io/babylon/x/zoneconcierge/types" ) -func init() { - types.EnableIntegration = true -} - func signBLSWithBitmap(blsSKs []bls12381.PrivateKey, bm bitmap.Bitmap, msg []byte) (bls12381.Signature, error) { sigs := []bls12381.Signature{} for i := 0; i < len(blsSKs); i++ { @@ -38,6 +34,15 @@ func FuzzBTCTimestamp(f *testing.F) { datagen.AddRandomSeedsToFuzzer(f, 10) f.Fuzz(func(t *testing.T, seed int64) { + // Save the original value of EnableIntegration + originalEnableIntegration := types.EnableIntegration + // Restore the original value after the test + defer func() { + types.EnableIntegration = originalEnableIntegration + }() + // Set EnableIntegration to true + types.EnableIntegration = true + r := rand.New(rand.NewSource(seed)) // generate the validator set with 10 validators as genesis genesisValSet, privSigner, err := datagen.GenesisValidatorSetWithPrivSigner(10)