diff --git a/handlers.go b/handlers.go index 2bb3b86..23a6d27 100644 --- a/handlers.go +++ b/handlers.go @@ -23,8 +23,9 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/fullstorydev/grpcui/internal" "github.com/fullstorydev/grpcurl" + + "github.com/fullstorydev/grpcui/internal" ) // RPCInvokeHandler returns an HTTP handler that can be used to invoke RPCs. The @@ -66,8 +67,30 @@ type InvokeOptions struct { // of a bool "verbose" flag, so that additional logs may be added in the // future and the caller control how detailed those logs will be. Verbosity int + // Middlewares + Middlewares []Middleware } +type RPCRequest struct { + MethodName string + Conn grpc.ClientConnInterface + DescSource grpcurl.DescriptorSource + Headers http.Header + Body io.Reader + Options *InvokeOptions +} + +type RPCHandler func( + ctx context.Context, + req *RPCRequest, +) (*RPCResult, error) + +type Middleware func( + ctx context.Context, + req *RPCRequest, + next RPCHandler, +) (*RPCResult, error) + // RPCInvokeHandlerWithOptions is the same as RPCInvokeHandler except that it // accepts an additional argument, options. This can be used to add extra // request metadata to all RPCs invoked. @@ -96,7 +119,25 @@ func RPCInvokeHandlerWithOptions(ch grpc.ClientConnInterface, descs []*desc.Meth http.Error(w, "Failed to create descriptor source: "+err.Error(), http.StatusInternalServerError) return } - results, err := invokeRPC(r.Context(), method, ch, descSource, r.Header, r.Body, &options) + req := &RPCRequest{ + MethodName: method, + Conn: ch, + DescSource: descSource, + Headers: r.Header, + Body: r.Body, + Options: &options, + } + call := func(ctx context.Context, req *RPCRequest) (*RPCResult, error) { + return invokeRPC(ctx, method, ch, descSource, r.Header, r.Body, &options) + } + for i := len(options.Middlewares) - 1; i > 0; i-- { + mw := options.Middlewares[i] + c2 := call + call = func(ctx context.Context, req *RPCRequest) (*RPCResult, error) { + return mw(ctx, req, c2) + } + } + results, err := call(r.Context(), req) if err != nil { if _, ok := err.(errReadFail); ok { http.Error(w, "Failed to read request", 499) @@ -414,7 +455,7 @@ func (e errReadFail) Error() string { return e.err.Error() } -func invokeRPC(ctx context.Context, methodName string, ch grpc.ClientConnInterface, descSource grpcurl.DescriptorSource, reqHdrs http.Header, body io.Reader, options *InvokeOptions) (*rpcResult, error) { +func invokeRPC(ctx context.Context, methodName string, ch grpc.ClientConnInterface, descSource grpcurl.DescriptorSource, reqHdrs http.Header, body io.Reader, options *InvokeOptions) (*RPCResult, error) { js, err := io.ReadAll(body) if err != nil { return nil, errReadFail{err: err} @@ -425,7 +466,7 @@ func invokeRPC(ctx context.Context, methodName string, ch grpc.ClientConnInterfa return nil, errBadInput{err: err} } - reqStats := rpcRequestStats{ + reqStats := RPCRequestStats{ Total: len(input.Data), } requestFunc := func(m proto.Message) error { @@ -458,7 +499,7 @@ func invokeRPC(ctx context.Context, methodName string, ch grpc.ClientConnInterfa defer cancel() } - result := rpcResult{ + result := RPCResult{ descSource: descSource, emitDefaults: options.EmitDefaults, Requests: &reqStats, @@ -513,91 +554,91 @@ func (opts *InvokeOptions) computeHeaders(reqHdrs http.Header, webFormHdrs metad return result } -type rpcMetadata struct { +type RPCMetadata struct { Name string `json:"name"` Value string `json:"value"` } type rpcInput struct { TimeoutSeconds float32 `json:"timeout_seconds"` - Metadata []rpcMetadata `json:"metadata"` + Metadata []RPCMetadata `json:"metadata"` Data []json.RawMessage `json:"data"` } -type rpcResponseElement struct { +type RPCResponseElement struct { Data json.RawMessage `json:"message"` IsError bool `json:"isError"` } -type rpcRequestStats struct { +type RPCRequestStats struct { Total int `json:"total"` Sent int `json:"sent"` } -type rpcError struct { +type RPCError struct { Code uint32 `json:"code"` Name string `json:"name"` Message string `json:"message"` - Details []rpcResponseElement `json:"details"` + Details []RPCResponseElement `json:"details"` } -type rpcResult struct { +type RPCResult struct { descSource grpcurl.DescriptorSource emitDefaults bool - Headers []rpcMetadata `json:"headers"` - Error *rpcError `json:"error"` - Responses []rpcResponseElement `json:"responses"` - Requests *rpcRequestStats `json:"requests"` - Trailers []rpcMetadata `json:"trailers"` + Headers []RPCMetadata `json:"headers"` + Error *RPCError `json:"error"` + Responses []RPCResponseElement `json:"responses"` + Requests *RPCRequestStats `json:"requests"` + Trailers []RPCMetadata `json:"trailers"` } -func (*rpcResult) OnResolveMethod(*desc.MethodDescriptor) {} +func (*RPCResult) OnResolveMethod(*desc.MethodDescriptor) {} -func (*rpcResult) OnSendHeaders(metadata.MD) {} +func (*RPCResult) OnSendHeaders(metadata.MD) {} -func (r *rpcResult) OnReceiveHeaders(md metadata.MD) { +func (r *RPCResult) OnReceiveHeaders(md metadata.MD) { r.Headers = responseMetadata(md) } -func (r *rpcResult) OnReceiveResponse(m proto.Message) { +func (r *RPCResult) OnReceiveResponse(m proto.Message) { r.Responses = append(r.Responses, responseToJSON(r.descSource, m, r.emitDefaults)) } -func (r *rpcResult) OnReceiveTrailers(stat *status.Status, md metadata.MD) { +func (r *RPCResult) OnReceiveTrailers(stat *status.Status, md metadata.MD) { r.Trailers = responseMetadata(md) r.Error = toRpcError(r.descSource, stat, r.emitDefaults) } -func responseMetadata(md metadata.MD) []rpcMetadata { +func responseMetadata(md metadata.MD) []RPCMetadata { keys := make([]string, 0, len(md)) for k := range md { keys = append(keys, k) } sort.Strings(keys) - ret := make([]rpcMetadata, 0, len(md)) + ret := make([]RPCMetadata, 0, len(md)) for _, k := range keys { vals := md[k] for _, v := range vals { if strings.HasSuffix(k, "-bin") { v = base64.StdEncoding.EncodeToString([]byte(v)) } - ret = append(ret, rpcMetadata{Name: k, Value: v}) + ret = append(ret, RPCMetadata{Name: k, Value: v}) } } return ret } -func toRpcError(descSource grpcurl.DescriptorSource, stat *status.Status, emitDefaults bool) *rpcError { +func toRpcError(descSource grpcurl.DescriptorSource, stat *status.Status, emitDefaults bool) *RPCError { if stat.Code() == codes.OK { return nil } details := stat.Proto().Details - msgs := make([]rpcResponseElement, len(details)) + msgs := make([]RPCResponseElement, len(details)) for i, d := range details { msgs[i] = responseToJSON(descSource, d, emitDefaults) } - return &rpcError{ + return &RPCError{ Code: uint32(stat.Code()), Name: stat.Code().String(), Message: stat.Message(), @@ -605,12 +646,12 @@ func toRpcError(descSource grpcurl.DescriptorSource, stat *status.Status, emitDe } } -func responseToJSON(descSource grpcurl.DescriptorSource, msg proto.Message, emitDefaults bool) rpcResponseElement { +func responseToJSON(descSource grpcurl.DescriptorSource, msg proto.Message, emitDefaults bool) RPCResponseElement { anyResolver := grpcurl.AnyResolverFromDescriptorSourceWithFallback(descSource) jsm := jsonpb.Marshaler{EmitDefaults: emitDefaults, OrigName: true, Indent: " ", AnyResolver: anyResolver} var b bytes.Buffer if err := jsm.Marshal(&b, msg); err == nil { - return rpcResponseElement{Data: json.RawMessage(b.Bytes())} + return RPCResponseElement{Data: json.RawMessage(b.Bytes())} } else { b, err := json.Marshal(err.Error()) if err != nil { @@ -618,6 +659,6 @@ func responseToJSON(descSource grpcurl.DescriptorSource, msg proto.Message, emit // should never happen... here's a dumb fallback b = []byte(strconv.Quote(err.Error())) } - return rpcResponseElement{Data: b, IsError: true} + return RPCResponseElement{Data: b, IsError: true} } } diff --git a/standalone/opts.go b/standalone/opts.go index 6f5c99b..80ff68f 100644 --- a/standalone/opts.go +++ b/standalone/opts.go @@ -6,6 +6,8 @@ import ( "html/template" "io" "path" + + "github.com/fullstorydev/grpcui" ) // WebFormContainerTemplateData is the param type for templates that embed the webform HTML. @@ -219,6 +221,13 @@ func WithClientDebug(debug bool) HandlerOption { }) } +// WithMiddlewares adds middlewares to be called before/after each request. +func WithMiddlewares(mws ...grpcui.Middleware) HandlerOption { + return optFunc(func(opts *handlerOptions) { + opts.middlewares = mws + }) +} + // optFunc implements HandlerOption type optFunc func(opts *handlerOptions) @@ -239,6 +248,7 @@ type handlerOptions struct { emitDefaults bool invokeVerbosity int debug *bool + middlewares []grpcui.Middleware } func (opts *handlerOptions) addlServedResources() []*resource { diff --git a/standalone/standalone.go b/standalone/standalone.go index 268d37b..274a33b 100644 --- a/standalone/standalone.go +++ b/standalone/standalone.go @@ -95,6 +95,7 @@ func Handler(ch grpcdynamic.Channel, target string, methods []*desc.MethodDescri PreserveHeaders: uiOpts.preserveHeaders, EmitDefaults: uiOpts.emitDefaults, Verbosity: uiOpts.invokeVerbosity, + Middlewares: uiOpts.middlewares, } rpcInvokeHandler := http.StripPrefix("/invoke", grpcui.RPCInvokeHandlerWithOptions(ch, methods, invokeOpts)) mux.HandleFunc("/invoke/", func(w http.ResponseWriter, r *http.Request) {