From ef856239b288496f6faa7967e73127ac87ad9c39 Mon Sep 17 00:00:00 2001 From: Tyler <48813565+technicallyty@users.noreply.github.com> Date: Mon, 13 Jan 2025 13:13:13 -0800 Subject: [PATCH] refactor(server/v2): auto-gateway improvements (#23262) Co-authored-by: Alex | Interchain Labs (cherry picked from commit b461a3142af55f96554ab5d99c7e56d63b1be286) # Conflicts: # server/v2/api/grpcgateway/interceptor.go # server/v2/api/grpcgateway/server.go # server/v2/api/grpcgateway/uri.go # server/v2/api/grpcgateway/uri_test.go --- server/v2/api/grpcgateway/doc.go | 11 + server/v2/api/grpcgateway/interceptor.go | 274 +++++++++++++++ server/v2/api/grpcgateway/interceptor_test.go | 313 ++++++++++++++++++ server/v2/api/grpcgateway/server.go | 156 +++++++++ server/v2/api/grpcgateway/uri.go | 91 +++++ server/v2/api/grpcgateway/uri_test.go | 172 ++++++++++ tests/systemtests/bank_test.go | 2 +- tests/systemtests/distribution_test.go | 7 +- 8 files changed, 1021 insertions(+), 5 deletions(-) create mode 100644 server/v2/api/grpcgateway/doc.go create mode 100644 server/v2/api/grpcgateway/interceptor.go create mode 100644 server/v2/api/grpcgateway/interceptor_test.go create mode 100644 server/v2/api/grpcgateway/server.go create mode 100644 server/v2/api/grpcgateway/uri.go create mode 100644 server/v2/api/grpcgateway/uri_test.go diff --git a/server/v2/api/grpcgateway/doc.go b/server/v2/api/grpcgateway/doc.go new file mode 100644 index 000000000000..cbdce577f2fc --- /dev/null +++ b/server/v2/api/grpcgateway/doc.go @@ -0,0 +1,11 @@ +// Package grpcgateway provides a custom http mux that utilizes the global gogoproto registry to match +// grpc gateway requests to query handlers. POST requests with JSON bodies and GET requests with query params are supported. +// Wildcard endpoints (i.e. foo/bar/{baz}), as well as catch-all endpoints (i.e. foo/bar/{baz=**} are supported. Using +// header `x-cosmos-block-height` allows you to specify a height for the query. +// +// The URL matching logic is achieved by building regular expressions from the gateway HTTP annotations. These regular expressions +// are then used to match against incoming requests to the HTTP server. +// +// In cases where the custom http mux is unable to handle the query (i.e. no match found), the request will fall back to the +// ServeMux from github.com/grpc-ecosystem/grpc-gateway/runtime. +package grpcgateway diff --git a/server/v2/api/grpcgateway/interceptor.go b/server/v2/api/grpcgateway/interceptor.go new file mode 100644 index 000000000000..81d1c3e32f4b --- /dev/null +++ b/server/v2/api/grpcgateway/interceptor.go @@ -0,0 +1,274 @@ +package grpcgateway + +import ( + "errors" + "io" + "net/http" + "reflect" + "regexp" + "strconv" + "strings" + + gogoproto "github.com/cosmos/gogoproto/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "github.com/mitchellh/mapstructure" + "google.golang.org/genproto/googleapis/api/annotations" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "cosmossdk.io/core/transaction" + "cosmossdk.io/log" + "cosmossdk.io/server/v2/appmanager" +) + +const MaxBodySize = 1 << 20 // 1 MB + +var _ http.Handler = &gatewayInterceptor[transaction.Tx]{} + +// queryMetadata holds information related to handling gateway queries. +type queryMetadata struct { + // queryInputProtoName is the proto name of the query's input type. + queryInputProtoName string + // wildcardKeyNames are the wildcard key names from the query's HTTP annotation. + // for example /foo/bar/{baz}/{qux} would produce []string{"baz", "qux"} + // this is used for building the query's parameter map. + wildcardKeyNames []string +} + +// gatewayInterceptor handles routing grpc-gateway queries to the app manager's query router. +type gatewayInterceptor[T transaction.Tx] struct { + logger log.Logger + // gateway is the fallback grpc gateway mux handler. + gateway *runtime.ServeMux + + matcher uriMatcher + + // appManager is used to route queries to the application. + appManager appmanager.AppManager[T] +} + +// newGatewayInterceptor creates a new gatewayInterceptor. +func newGatewayInterceptor[T transaction.Tx](logger log.Logger, gateway *runtime.ServeMux, am appmanager.AppManager[T]) (*gatewayInterceptor[T], error) { + getMapping, err := getHTTPGetAnnotationMapping() + if err != nil { + return nil, err + } + // convert the mapping to regular expressions for URL matching. + wildcardMatchers, simpleMatchers := createRegexMapping(logger, getMapping) + matcher := uriMatcher{ + wildcardURIMatchers: wildcardMatchers, + simpleMatchers: simpleMatchers, + } + return &gatewayInterceptor[T]{ + logger: logger, + gateway: gateway, + matcher: matcher, + appManager: am, + }, nil +} + +// ServeHTTP implements the http.Handler interface. This method will attempt to match request URIs to its internal mapping +// of gateway HTTP annotations. If no match can be made, it falls back to the runtime gateway server mux. +func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + g.logger.Debug("received grpc-gateway request", "request_uri", request.RequestURI) + match := g.matcher.matchURL(request.URL) + if match == nil { + // no match cases fall back to gateway mux. + g.gateway.ServeHTTP(writer, request) + return + } + + g.logger.Debug("matched request", "query_input", match.QueryInputName) + + in, out := runtime.MarshalerForRequest(g.gateway, request) + + // extract the proto message type. + msgType := gogoproto.MessageType(match.QueryInputName) + msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message) + if !ok { + runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName)) + return + } + + // msg population based on http method. + var inputMsg gogoproto.Message + var err error + switch request.Method { + case http.MethodGet: + inputMsg, err = g.createMessageFromGetRequest(request, msg, match.Params) + case http.MethodPost: + inputMsg, err = g.createMessageFromPostRequest(in, request, msg) + default: + runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET")) + return + } + if err != nil { + // the errors returned from the message creation methods return status errors. no need to make one here. + runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err) + return + } + + // get the height from the header. + var height uint64 + heightStr := request.Header.Get(GRPCBlockHeightHeader) + heightStr = strings.Trim(heightStr, `\"`) + if heightStr != "" && heightStr != "latest" { + height, err = strconv.ParseUint(heightStr, 10, 64) + if err != nil { + runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr)) + return + } + } + + responseMsg, err := g.appManager.Query(request.Context(), height, inputMsg) + if err != nil { + // if we couldn't find a handler for this request, just fall back to the gateway mux. + if strings.Contains(err.Error(), "no handler") { + g.gateway.ServeHTTP(writer, request) + } else { + // for all other errors, we just return the error. + runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err) + } + return + } + + // for no errors, we forward the response. + runtime.ForwardResponseMessage(request.Context(), g.gateway, out, writer, request, responseMsg) +} + +func (g *gatewayInterceptor[T]) createMessageFromPostRequest(marshaler runtime.Marshaler, req *http.Request, input gogoproto.Message) (gogoproto.Message, error) { + if req.ContentLength > MaxBodySize { + return nil, status.Errorf(codes.InvalidArgument, "request body too large: %d bytes, max=%d", req.ContentLength, MaxBodySize) + } + newReader, err := utilities.IOReaderFactory(req.Body) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + + if err = marshaler.NewDecoder(newReader()).Decode(input); err != nil && !errors.Is(err, io.EOF) { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + + return input, nil +} + +func (g *gatewayInterceptor[T]) createMessageFromGetRequest(req *http.Request, input gogoproto.Message, wildcardValues map[string]string) (gogoproto.Message, error) { + // decode the path wildcards into the message. + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: input, + TagName: "json", + WeaklyTypedInput: true, + }) + if err != nil { + return nil, status.Error(codes.Internal, "failed to create message decoder") + } + if err := decoder.Decode(wildcardValues); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + if err = req.ParseForm(); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + + filter := filterFromPathParams(wildcardValues) + err = runtime.PopulateQueryParameters(input, req.Form, filter) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + + return input, err +} + +func filterFromPathParams(pathParams map[string]string) *utilities.DoubleArray { + var prefixPaths [][]string + + for k := range pathParams { + prefixPaths = append(prefixPaths, []string{k}) + } + + return utilities.NewDoubleArray(prefixPaths) +} + +// getHTTPGetAnnotationMapping returns a mapping of RPC Method HTTP GET annotation to the RPC Handler's Request Input type full name. +// +// example: "/cosmos/auth/v1beta1/account_info/{address}":"cosmos.auth.v1beta1.Query.AccountInfo" +func getHTTPGetAnnotationMapping() (map[string]string, error) { + protoFiles, err := gogoproto.MergedRegistry() + if err != nil { + return nil, err + } + + annotationToQueryInputName := make(map[string]string) + protoFiles.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + for i := 0; i < fd.Services().Len(); i++ { + serviceDesc := fd.Services().Get(i) + for j := 0; j < serviceDesc.Methods().Len(); j++ { + methodDesc := serviceDesc.Methods().Get(j) + httpExtension := proto.GetExtension(methodDesc.Options(), annotations.E_Http) + if httpExtension == nil { + continue + } + + httpRule, ok := httpExtension.(*annotations.HttpRule) + if !ok || httpRule == nil { + continue + } + queryInputName := string(methodDesc.Input().FullName()) + annotations := append(httpRule.GetAdditionalBindings(), httpRule) + for _, a := range annotations { + if httpAnnotation := a.GetGet(); httpAnnotation != "" { + annotationToQueryInputName[httpAnnotation] = queryInputName + } + if httpAnnotation := a.GetPost(); httpAnnotation != "" { + annotationToQueryInputName[httpAnnotation] = queryInputName + } + } + } + } + return true + }) + return annotationToQueryInputName, nil +} + +// createRegexMapping converts the annotationMapping (HTTP annotation -> query input type name) to a +// map of regular expressions for that HTTP annotation pattern, to queryMetadata. +func createRegexMapping(logger log.Logger, annotationMapping map[string]string) (map[*regexp.Regexp]queryMetadata, map[string]queryMetadata) { + wildcardMatchers := make(map[*regexp.Regexp]queryMetadata) + // seen patterns is a map of URI patterns to annotations. for simple queries (no wildcards) the annotation is used + // for the key. + seenPatterns := make(map[string]string) + simpleMatchers := make(map[string]queryMetadata) + + for annotation, queryInputName := range annotationMapping { + pattern, wildcardNames := patternToRegex(annotation) + if len(wildcardNames) == 0 { + if otherAnnotation, ok := seenPatterns[annotation]; ok { + // TODO: eventually we want this to error, but there is currently a duplicate in the protobuf. + // see: https://github.com/cosmos/cosmos-sdk/issues/23281 + logger.Warn("duplicate HTTP annotation found", "annotation1", annotation, "annotation2", otherAnnotation, "query_input_name", queryInputName) + } + simpleMatchers[annotation] = queryMetadata{ + queryInputProtoName: queryInputName, + wildcardKeyNames: nil, + } + seenPatterns[annotation] = annotation + } else { + reg := regexp.MustCompile(pattern) + if otherAnnotation, ok := seenPatterns[pattern]; ok { + // TODO: eventually we want this to error, but there is currently a duplicate in the protobuf. + // see: https://github.com/cosmos/cosmos-sdk/issues/23281 + logger.Warn("duplicate HTTP annotation found", "annotation1", annotation, "annotation2", otherAnnotation, "query_input_name", queryInputName) + } + wildcardMatchers[reg] = queryMetadata{ + queryInputProtoName: queryInputName, + wildcardKeyNames: wildcardNames, + } + seenPatterns[pattern] = annotation + + } + } + return wildcardMatchers, simpleMatchers +} diff --git a/server/v2/api/grpcgateway/interceptor_test.go b/server/v2/api/grpcgateway/interceptor_test.go new file mode 100644 index 000000000000..80512f2c21df --- /dev/null +++ b/server/v2/api/grpcgateway/interceptor_test.go @@ -0,0 +1,313 @@ +package grpcgateway + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + gogoproto "github.com/cosmos/gogoproto/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "cosmossdk.io/core/transaction" + "cosmossdk.io/log" +) + +func Test_createRegexMapping(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + expectedRegex int + expectedSimple int + wantWarn bool + }{ + { + name: "no annotations should not warn", + }, + { + name: "expected correct amount of regex and simple matchers", + annotations: map[string]string{ + "/foo/bar/baz": "", + "/foo/{bar}/baz": "", + "/foo/bar/bell": "", + }, + expectedRegex: 1, + expectedSimple: 2, + }, + { + name: "different annotations should not warn", + annotations: map[string]string{ + "/foo/bar/{baz}": "", + "/crypto/{currency}": "", + }, + expectedRegex: 2, + }, + { + name: "duplicate annotations should warn", + annotations: map[string]string{ + "/hello/{world}": "", + "/hello/{developers}": "", + }, + expectedRegex: 2, + wantWarn: true, + }, + } + buf := bytes.NewBuffer(nil) + logger := log.NewLogger(buf) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + regex, simple := createRegexMapping(logger, tt.annotations) + if tt.wantWarn { + require.NotEmpty(t, buf.String()) + } else { + require.Empty(t, buf.String()) + } + require.Equal(t, tt.expectedRegex, len(regex)) + require.Equal(t, tt.expectedSimple, len(simple)) + }) + } +} + +func TestCreateMessageFromGetRequest(t *testing.T) { + gogoproto.RegisterType(&DummyProto{}, dummyProtoName) + + testCases := []struct { + name string + request func() *http.Request + wildcardValues map[string]string + expected *DummyProto + wantErr bool + errCode codes.Code + }{ + { + name: "simple wildcard + query params", + request: func() *http.Request { + // GET with query params: ?bar=true&baz=42&denoms=atom&denoms=osmo + // Also nested pagination params: page.limit=100, page.nest.foo=999 + req := httptest.NewRequest( + http.MethodGet, + "/dummy?bar=true&baz=42&denoms=atom&denoms=osmo&page.limit=100&page.nest.foo=999", + nil, + ) + return req + }, + wildcardValues: map[string]string{ + "foo": "wildFooValue", // from path wildcard e.g. /dummy/{foo} + }, + expected: &DummyProto{ + Foo: "wildFooValue", + Bar: true, + Baz: 42, + Denoms: []string{"atom", "osmo"}, + Page: &Pagination{ + Limit: 100, + Nest: &Nested{ + Foo: 999, + }, + }, + }, + wantErr: false, + }, + { + name: "invalid integer in query param", + request: func() *http.Request { + req := httptest.NewRequest( + http.MethodGet, + "/dummy?baz=notanint", + nil, + ) + return req + }, + wildcardValues: map[string]string{}, + expected: &DummyProto{}, // won't get populated + wantErr: true, + errCode: codes.InvalidArgument, + }, + { + name: "no query params, but wildcard set", + request: func() *http.Request { + // No query params. Only the wildcard. + req := httptest.NewRequest( + http.MethodGet, + "/dummy", + nil, + ) + return req + }, + wildcardValues: map[string]string{ + "foo": "barFromWildcard", + }, + expected: &DummyProto{ + Foo: "barFromWildcard", + }, + wantErr: false, + }, + } + + // We only need a minimal gatewayInterceptor instance to call createMessageFromGetRequest, + // so it's fine to leave most fields nil for this unit test. + g := &gatewayInterceptor[transaction.Tx]{} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := tc.request() + + inputMsg := &DummyProto{} + gotMsg, err := g.createMessageFromGetRequest( + req, + inputMsg, + tc.wildcardValues, + ) + + if tc.wantErr { + require.Error(t, err, "expected error but got none") + st, ok := status.FromError(err) + if ok && tc.errCode != codes.OK { + require.Equal(t, tc.errCode, st.Code()) + } + } else { + require.NoError(t, err, "unexpected error") + require.Equal(t, tc.expected, gotMsg, "message contents do not match expected") + } + }) + } +} + +func TestCreateMessageFromPostRequest(t *testing.T) { + gogoproto.RegisterType(&DummyProto{}, dummyProtoName) + gogoproto.RegisterType(&Pagination{}, "pagination") + gogoproto.RegisterType(&Nested{}, "nested") + + testCases := []struct { + name string + body any + wantErr bool + errCode codes.Code + expected *DummyProto + }{ + { + name: "valid JSON body with nested fields", + body: map[string]any{ + "foo": "postFoo", + "bar": true, + "baz": 42, + "denoms": []string{"atom", "osmo"}, + "page": map[string]any{ + "limit": 100, + "nest": map[string]any{ + "foo": 999, + }, + }, + }, + wantErr: false, + expected: &DummyProto{ + Foo: "postFoo", + Bar: true, + Baz: 42, + Denoms: []string{"atom", "osmo"}, + Page: &Pagination{ + Limit: 100, + Nest: &Nested{ + Foo: 999, + }, + }, + }, + }, + { + name: "invalid JSON structure", + // Provide a broken JSON string: + body: `{"foo": "bad json", "extra": "not closed"`, + wantErr: true, + errCode: codes.InvalidArgument, + }, + { + name: "empty JSON object", + body: map[string]any{}, + wantErr: false, + expected: &DummyProto{}, // all fields remain zeroed + }, + } + + g := &gatewayInterceptor[transaction.Tx]{} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var reqBody []byte + switch typedBody := tc.body.(type) { + case string: + // This might be invalid JSON we intentionally want to test + reqBody = []byte(typedBody) + default: + // Marshal the given any into JSON + b, err := json.Marshal(typedBody) + require.NoError(t, err, "failed to marshal test body to JSON") + reqBody = b + } + + req := httptest.NewRequest(http.MethodPost, "/dummy", bytes.NewReader(reqBody)) + + inputMsg := &DummyProto{} + gotMsg, err := g.createMessageFromPostRequest( + &runtime.JSONPb{}, // JSONPb marshaler + req, + inputMsg, + ) + + if tc.wantErr { + require.Error(t, err, "expected an error but got none") + // Optionally verify the gRPC status code + st, ok := status.FromError(err) + if ok && tc.errCode != codes.OK { + require.Equal(t, tc.errCode, st.Code()) + } + } else { + require.NoError(t, err, "did not expect an error") + require.Equal(t, tc.expected, gotMsg) + } + }) + } +} + +/* +--- Testing Types --- +*/ +type Nested struct { + Foo int32 `protobuf:"varint,1,opt,name=foo,proto3" json:"foo,omitempty"` +} + +func (n Nested) Reset() {} + +func (n Nested) String() string { return "" } + +func (n Nested) ProtoMessage() {} + +type Pagination struct { + Limit int32 `protobuf:"varint,1,opt,name=limit,proto3" json:"limit,omitempty"` + Nest *Nested `protobuf:"bytes,2,opt,name=nest,proto3" json:"nest,omitempty"` +} + +func (p Pagination) Reset() {} + +func (p Pagination) String() string { return "" } + +func (p Pagination) ProtoMessage() {} + +const dummyProtoName = "dummy" + +type DummyProto struct { + Foo string `protobuf:"bytes,1,opt,name=foo,proto3" json:"foo,omitempty"` + Bar bool `protobuf:"varint,2,opt,name=bar,proto3" json:"bar,omitempty"` + Baz int32 `protobuf:"varint,3,opt,name=baz,proto3" json:"baz,omitempty"` + Denoms []string `protobuf:"bytes,4,rep,name=denoms,proto3" json:"denoms,omitempty"` + Page *Pagination `protobuf:"bytes,5,opt,name=page,proto3" json:"page,omitempty"` +} + +func (d DummyProto) Reset() {} + +func (d DummyProto) String() string { return dummyProtoName } + +func (d DummyProto) ProtoMessage() {} diff --git a/server/v2/api/grpcgateway/server.go b/server/v2/api/grpcgateway/server.go new file mode 100644 index 000000000000..59fb1b2ff1aa --- /dev/null +++ b/server/v2/api/grpcgateway/server.go @@ -0,0 +1,156 @@ +package grpcgateway + +import ( + "context" + "fmt" + "net/http" + "strings" + + gateway "github.com/cosmos/gogogateway" + "github.com/cosmos/gogoproto/jsonpb" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + + "cosmossdk.io/core/server" + "cosmossdk.io/core/transaction" + "cosmossdk.io/log" + serverv2 "cosmossdk.io/server/v2" + "cosmossdk.io/server/v2/appmanager" +) + +var ( + _ serverv2.ServerComponent[transaction.Tx] = (*Server[transaction.Tx])(nil) + _ serverv2.HasConfig = (*Server[transaction.Tx])(nil) +) + +const ServerName = "grpc-gateway" + +type Server[T transaction.Tx] struct { + logger log.Logger + config *Config + cfgOptions []CfgOption + + server *http.Server + GRPCGatewayRouter *runtime.ServeMux +} + +// New creates a new gRPC-gateway server. +func New[T transaction.Tx]( + logger log.Logger, + config server.ConfigMap, + ir jsonpb.AnyResolver, + appManager appmanager.AppManager[T], + cfgOptions ...CfgOption, +) (*Server[T], error) { + // The default JSON marshaller used by the gRPC-Gateway is unable to marshal non-nullable non-scalar fields. + // Using the gogo/gateway package with the gRPC-Gateway WithMarshaler option fixes the scalar field marshaling issue. + marshalerOption := &gateway.JSONPb{ + EmitDefaults: true, + Indent: "", + OrigName: true, + AnyResolver: ir, + } + + s := &Server[T]{ + GRPCGatewayRouter: runtime.NewServeMux( + // Custom marshaler option is required for gogo proto + runtime.WithMarshalerOption(runtime.MIMEWildcard, marshalerOption), + + // This is necessary to get error details properly + // marshaled in unary requests. + runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler), + + // Custom header uriMatcher for mapping request headers to + // GRPC metadata + runtime.WithIncomingHeaderMatcher(CustomGRPCHeaderMatcher), + ), + cfgOptions: cfgOptions, + } + + serverCfg := s.Config().(*Config) + if len(config) > 0 { + if err := serverv2.UnmarshalSubConfig(config, s.Name(), &serverCfg); err != nil { + return s, fmt.Errorf("failed to unmarshal config: %w", err) + } + } + + s.logger = logger.With(log.ModuleKey, s.Name()) + s.config = serverCfg + mux := http.NewServeMux() + interceptor, err := newGatewayInterceptor[T](logger, s.GRPCGatewayRouter, appManager) + if err != nil { + return nil, fmt.Errorf("failed to create grpc-gateway interceptor: %w", err) + } + mux.Handle("/", interceptor) + + s.server = &http.Server{ + Addr: s.config.Address, + Handler: mux, + } + return s, nil +} + +// NewWithConfigOptions creates a new gRPC-gateway server with the provided config options. +func NewWithConfigOptions[T transaction.Tx](opts ...CfgOption) *Server[T] { + return &Server[T]{ + cfgOptions: opts, + } +} + +func (s *Server[T]) Name() string { + return ServerName +} + +func (s *Server[T]) Config() any { + if s.config == nil || s.config.Address == "" { + cfg := DefaultConfig() + // overwrite the default config with the provided options + for _, opt := range s.cfgOptions { + opt(cfg) + } + + return cfg + } + + return s.config +} + +func (s *Server[T]) Start(ctx context.Context) error { + if !s.config.Enable { + s.logger.Info(fmt.Sprintf("%s server is disabled via config", s.Name())) + return nil + } + + s.logger.Info("starting gRPC-Gateway server...", "address", s.config.Address) + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("failed to start gRPC-Gateway server: %w", err) + } + + return nil +} + +func (s *Server[T]) Stop(ctx context.Context) error { + if !s.config.Enable { + return nil + } + + s.logger.Info("stopping gRPC-Gateway server...", "address", s.config.Address) + return s.server.Shutdown(ctx) +} + +// GRPCBlockHeightHeader is the gRPC header for block height. +const GRPCBlockHeightHeader = "x-cosmos-block-height" + +// CustomGRPCHeaderMatcher for mapping request headers to +// GRPC metadata. +// HTTP headers that start with 'Grpc-Metadata-' are automatically mapped to +// gRPC metadata after removing prefix 'Grpc-Metadata-'. We can use this +// CustomGRPCHeaderMatcher if headers don't start with `Grpc-Metadata-` +func CustomGRPCHeaderMatcher(key string) (string, bool) { + switch strings.ToLower(key) { + case GRPCBlockHeightHeader: + return GRPCBlockHeightHeader, true + + default: + return runtime.DefaultHeaderMatcher(key) + } +} diff --git a/server/v2/api/grpcgateway/uri.go b/server/v2/api/grpcgateway/uri.go new file mode 100644 index 000000000000..f5ebc25668fa --- /dev/null +++ b/server/v2/api/grpcgateway/uri.go @@ -0,0 +1,91 @@ +package grpcgateway + +import ( + "net/url" + "regexp" + "strings" +) + +// uriMatcher provides functionality to match HTTP request URIs. +type uriMatcher struct { + // wildcardURIMatchers are used for complex URIs that involve wildcards (i.e. /foo/{bar}/baz) + wildcardURIMatchers map[*regexp.Regexp]queryMetadata + // simpleMatchers are used for simple URI's that have no wildcards (i.e. /foo/bar/baz). + simpleMatchers map[string]queryMetadata +} + +// uriMatch contains information related to a URI match. +type uriMatch struct { + // QueryInputName is the fully qualified name of the proto input type of the query rpc method. + QueryInputName string + + // Params are any wildcard params found in the request. + // + // example: /foo/bar/{baz} -> /foo/bar/hello = {"baz": "hello"} + Params map[string]string +} + +// matchURL attempts to find a match for the given URL. +// NOTE: if no match is found, nil is returned. +func (m uriMatcher) matchURL(u *url.URL) *uriMatch { + uriPath := strings.TrimRight(u.Path, "/") + params := make(map[string]string) + + // see if we can get a simple match first. + if qmd, ok := m.simpleMatchers[uriPath]; ok { + return &uriMatch{ + QueryInputName: qmd.queryInputProtoName, + Params: params, + } + } + + // try the complex matchers. + for reg, qmd := range m.wildcardURIMatchers { + matches := reg.FindStringSubmatch(uriPath) + switch { + case len(matches) == 1: + return &uriMatch{ + QueryInputName: qmd.queryInputProtoName, + Params: params, + } + case len(matches) > 1: + // first match is the URI, subsequent matches are the wild card values. + for i, name := range qmd.wildcardKeyNames { + params[name] = matches[i+1] + } + + return &uriMatch{ + QueryInputName: qmd.queryInputProtoName, + Params: params, + } + } + } + return nil +} + +// patternToRegex converts a URI pattern with wildcards to a regex pattern. +// Returns the regex pattern and a slice of wildcard names in order +func patternToRegex(pattern string) (string, []string) { + escaped := regexp.QuoteMeta(pattern) + var wildcardNames []string + + // extract and replace {param=**} patterns + r1 := regexp.MustCompile(`\\\{([^}]+?)=\\\*\\\*\\}`) + escaped = r1.ReplaceAllStringFunc(escaped, func(match string) string { + // extract wildcard name without the =** suffix + name := regexp.MustCompile(`\\\{(.+?)=`).FindStringSubmatch(match)[1] + wildcardNames = append(wildcardNames, name) + return "(.+)" + }) + + // extract and replace {param} patterns + r2 := regexp.MustCompile(`\\\{([^}]+)\\}`) + escaped = r2.ReplaceAllStringFunc(escaped, func(match string) string { + // extract wildcard name from the curl braces {}. + name := regexp.MustCompile(`\\\{(.*?)\\}`).FindStringSubmatch(match)[1] + wildcardNames = append(wildcardNames, name) + return "([^/]+)" + }) + + return "^" + escaped + "$", wildcardNames +} diff --git a/server/v2/api/grpcgateway/uri_test.go b/server/v2/api/grpcgateway/uri_test.go new file mode 100644 index 000000000000..24d51fd59e54 --- /dev/null +++ b/server/v2/api/grpcgateway/uri_test.go @@ -0,0 +1,172 @@ +package grpcgateway + +import ( + "net/url" + "os" + "regexp" + "testing" + + "github.com/stretchr/testify/require" + + "cosmossdk.io/log" +) + +func TestMatchURI(t *testing.T) { + testCases := []struct { + name string + uri string + mapping map[string]string + expected *uriMatch + }{ + { + name: "simple match, no wildcards", + uri: "https://localhost:8080/foo/bar", + mapping: map[string]string{"/foo/bar": "query.Bank"}, + expected: &uriMatch{QueryInputName: "query.Bank", Params: map[string]string{}}, + }, + { + name: "match with wildcard similar to simple match - simple", + uri: "https://localhost:8080/bank/supply/latest", + mapping: map[string]string{ + "/bank/supply/{height}": "queryBankHeight", + "/bank/supply/latest": "queryBankLatest", + }, + expected: &uriMatch{QueryInputName: "queryBankLatest", Params: map[string]string{}}, + }, + { + name: "match with wildcard similar to simple match - wildcard", + uri: "https://localhost:8080/bank/supply/52", + mapping: map[string]string{ + "/bank/supply/{height}": "queryBankHeight", + "/bank/supply/latest": "queryBankLatest", + }, + expected: &uriMatch{QueryInputName: "queryBankHeight", Params: map[string]string{"height": "52"}}, + }, + { + name: "wildcard match at the end", + uri: "https://localhost:8080/foo/bar/buzz", + mapping: map[string]string{"/foo/bar/{baz}": "bar"}, + expected: &uriMatch{ + QueryInputName: "bar", + Params: map[string]string{"baz": "buzz"}, + }, + }, + { + name: "wildcard match in the middle", + uri: "https://localhost:8080/foo/buzz/bar", + mapping: map[string]string{"/foo/{baz}/bar": "bar"}, + expected: &uriMatch{ + QueryInputName: "bar", + Params: map[string]string{"baz": "buzz"}, + }, + }, + { + name: "multiple wild cards", + uri: "https://localhost:8080/foo/bar/baz/buzz", + mapping: map[string]string{"/foo/bar/{q1}/{q2}": "bar"}, + expected: &uriMatch{ + QueryInputName: "bar", + Params: map[string]string{"q1": "baz", "q2": "buzz"}, + }, + }, + { + name: "catch-all wildcard", + uri: "https://localhost:8080/foo/bar/ibc/token/stuff", + mapping: map[string]string{"/foo/bar/{ibc_token=**}": "bar"}, + expected: &uriMatch{ + QueryInputName: "bar", + Params: map[string]string{"ibc_token": "ibc/token/stuff"}, + }, + }, + { + name: "no match should return nil", + uri: "https://localhost:8080/foo/bar", + mapping: map[string]string{"/bar/foo": "bar"}, + expected: nil, + }, + } + + logger := log.NewLogger(os.Stdout) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + u, err := url.Parse(tc.uri) + require.NoError(t, err) + + regexpMatchers, simpleMatchers := createRegexMapping(logger, tc.mapping) + matcher := uriMatcher{ + wildcardURIMatchers: regexpMatchers, + simpleMatchers: simpleMatchers, + } + + actual := matcher.matchURL(u) + require.Equal(t, tc.expected, actual) + }) + } +} + +func Test_patternToRegex(t *testing.T) { + tests := []struct { + name string + pattern string + wildcards []string + wildcardValues []string + shouldMatch string + shouldNotMatch []string + }{ + { + name: "simple match, no wildcards", + pattern: "/foo/bar/baz", + shouldMatch: "/foo/bar/baz", + shouldNotMatch: []string{"/foo/bar", "/foo", "/foo/bar/baz/boo"}, + }, + { + name: "match with wildcard", + pattern: "/foo/bar/{baz}", + wildcards: []string{"baz"}, + shouldMatch: "/foo/bar/hello", + wildcardValues: []string{"hello"}, + shouldNotMatch: []string{"/foo/bar", "/foo/bar/baz/boo"}, + }, + { + name: "match with multiple wildcards", + pattern: "/foo/{bar}/{baz}/meow", + wildcards: []string{"bar", "baz"}, + shouldMatch: "/foo/hello/world/meow", + wildcardValues: []string{"hello", "world"}, + shouldNotMatch: []string{"/foo/bar/baz/boo", "/foo/bar/baz"}, + }, + { + name: "match catch-all wildcard", + pattern: `/foo/bar/{baz=**}`, + wildcards: []string{"baz"}, + shouldMatch: `/foo/bar/this/is/a/long/wildcard`, + wildcardValues: []string{"this/is/a/long/wildcard"}, + shouldNotMatch: []string{"/foo/bar", "/foo", "/foo/baz/bar/long/wild/card"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + regString, wildcards := patternToRegex(tt.pattern) + // should produce the same wildcard keys + require.Equal(t, tt.wildcards, wildcards) + reg := regexp.MustCompile(regString) + + // handle the "should match" case. + matches := reg.FindStringSubmatch(tt.shouldMatch) + require.True(t, len(matches) > 0) // there should always be a match. + // when matches > 1, this means we got wildcard values to handle. the test should have wildcard values. + if len(matches) > 1 { + require.Greater(t, len(tt.wildcardValues), 0) + } + // matches[0] is the URL, everything else should be those wildcard values. + if len(tt.wildcardValues) > 0 { + require.Equal(t, matches[1:], tt.wildcardValues) + } + + // should never match these. + for _, notMatch := range tt.shouldNotMatch { + require.Len(t, reg.FindStringSubmatch(notMatch), 0) + } + }) + } +} diff --git a/tests/systemtests/bank_test.go b/tests/systemtests/bank_test.go index c5c45bca6d63..5dfb98c32fef 100644 --- a/tests/systemtests/bank_test.go +++ b/tests/systemtests/bank_test.go @@ -263,7 +263,7 @@ func TestBankGRPCQueries(t *testing.T) { "error when querying supply with height greater than block height", supplyUrl, map[string]string{ - blockHeightHeader: fmt.Sprintf("%d", blockHeight+5), + blockHeightHeader: fmt.Sprintf("%d", blockHeight+5000), }, http.StatusBadRequest, "invalid height", diff --git a/tests/systemtests/distribution_test.go b/tests/systemtests/distribution_test.go index 2602df526d91..9170e99728da 100644 --- a/tests/systemtests/distribution_test.go +++ b/tests/systemtests/distribution_test.go @@ -179,8 +179,7 @@ func TestDistrValidatorGRPCQueries(t *testing.T) { // test validator slashes grpc endpoint slashURL := baseurl + `/cosmos/distribution/v1beta1/validators/%s/slashes` - invalidStartingHeightOutput := `{"code":3, "message":"1 error(s) decoding:\n\n* cannot parse 'starting_height' as uint: strconv.ParseUint: parsing \"-3\": invalid syntax", "details":[]}` - invalidEndingHeightOutput := `{"code":3, "message":"1 error(s) decoding:\n\n* cannot parse 'ending_height' as uint: strconv.ParseUint: parsing \"-3\": invalid syntax", "details":[]}` + invalidHeightOutput := `{"code":"NUMBER", "details":[]interface {}{}, "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}` if !systest.IsV2() { invalidStartingHeightOutput = `{"code":3, "message":"strconv.ParseUint: parsing \"-3\": invalid syntax", "details":[]}` @@ -192,13 +191,13 @@ func TestDistrValidatorGRPCQueries(t *testing.T) { Name: "invalid start height", Url: fmt.Sprintf(slashURL+`?starting_height=%s&ending_height=%s`, valOperAddr, "-3", "3"), ExpCode: http.StatusBadRequest, - ExpOut: invalidStartingHeightOutput, + ExpOut: invalidHeightOutput, }, { Name: "invalid end height", Url: fmt.Sprintf(slashURL+`?starting_height=%s&ending_height=%s`, valOperAddr, "1", "-3"), ExpCode: http.StatusBadRequest, - ExpOut: invalidEndingHeightOutput, + ExpOut: invalidHeightOutput, }, { Name: "valid request get slashes",