Skip to content

Commit

Permalink
add DisableHTTPHeader field
Browse files Browse the repository at this point in the history
  • Loading branch information
andot committed Mar 7, 2022
1 parent bfab4ee commit 6ceaaab
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 33 deletions.
27 changes: 16 additions & 11 deletions rpc/http/fasthttp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
| |
| rpc/http/fasthttp/transport.go |
| |
| LastModified: Mar 6, 2022 |
| LastModified: Mar 7, 2022 |
| Author: Ma Bingyao <[email protected]> |
| |
\*________________________________________________________*/
Expand All @@ -26,10 +26,11 @@ import (
)

type Transport struct {
Header http.Header
FastHTTPClient fasthttp.Client
compression bool
keepAlive bool
DisableHTTPHeader bool
Header http.Header
FastHTTPClient fasthttp.Client
compression bool
keepAlive bool
*cookieManager
}

Expand All @@ -56,11 +57,13 @@ func (trans *Transport) Transport(ctx context.Context, request []byte) (response
req.Header.SetMethod("POST")
req.SetRequestURI(clientContext.URL.String())
req.SetBody(request)
if trans.Header != nil {
addRequestHeader(&req.Header, trans.Header)
}
if header, ok := clientContext.Items().GetInterface("httpRequestHeaders").(http.Header); ok {
addRequestHeader(&req.Header, header)
if !trans.DisableHTTPHeader {
if trans.Header != nil {
addRequestHeader(&req.Header, trans.Header)
}
if header, ok := clientContext.Items().GetInterface("httpRequestHeaders").(http.Header); ok {
addRequestHeader(&req.Header, header)
}
}
if trans.keepAlive {
req.Header.Set("Connection", "keep-alive")
Expand Down Expand Up @@ -90,7 +93,9 @@ func (trans *Transport) Transport(ctx context.Context, request []byte) (response
clientContext.Items().Set("httpStatusText", string(resp.Header.StatusMessage()))
switch resp.Header.StatusCode() {
case fasthttp.StatusOK:
clientContext.Items().Set("httpResponseHeaders", getResponseHeader(&resp.Header))
if !trans.DisableHTTPHeader {
clientContext.Items().Set("httpResponseHeaders", getResponseHeader(&resp.Header))
}
return resp.SwapBody(nil), nil
case fasthttp.StatusRequestEntityTooLarge:
return nil, core.ErrRequestEntityTooLarge
Expand Down
35 changes: 22 additions & 13 deletions rpc/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
| |
| rpc/http/handler.go |
| |
| LastModified: Mar 6, 2022 |
| LastModified: Mar 7, 2022 |
| Author: Ma Bingyao <[email protected]> |
| |
\*________________________________________________________*/
Expand Down Expand Up @@ -36,6 +36,7 @@ type Handler struct {
P3P bool
GET bool
CrossDomain bool
DisableHTTPHeader bool
Header http.Header
AccessControlAllowOrigins map[string]bool
LastModified string
Expand Down Expand Up @@ -214,11 +215,13 @@ func (h *Handler) sendHeader(serviceContext *core.ServiceContext, response http.
responseHeader.Set("Access-Control-Allow-Origin", "*")
}
}
if h.Header != nil {
addHeader(responseHeader, h.Header)
}
if header, ok := serviceContext.Items().GetInterface("httpResponseHeaders").(http.Header); ok {
addHeader(responseHeader, header)
if !h.DisableHTTPHeader {
if h.Header != nil {
addHeader(responseHeader, h.Header)
}
if header, ok := serviceContext.Items().GetInterface("httpResponseHeaders").(http.Header); ok {
addHeader(responseHeader, header)
}
}
if code := serviceContext.Items().GetInt("httpStatusCode"); code != 0 {
response.WriteHeader(code)
Expand All @@ -229,7 +232,9 @@ func (h *Handler) getServiceContext(response http.ResponseWriter, request *http.
serviceContext := core.NewServiceContext(h.Service)
serviceContext.Items().Set("request", request)
serviceContext.Items().Set("response", response)
serviceContext.Items().Set("httpRequestHeaders", request.Header)
if !h.DisableHTTPHeader {
serviceContext.Items().Set("httpRequestHeaders", request.Header)
}
serviceContext.LocalAddr, _ = net.ResolveTCPAddr("tcp", request.Host)
serviceContext.RemoteAddr, _ = net.ResolveTCPAddr("tcp", request.RemoteAddr)
serviceContext.Handler = h
Expand Down Expand Up @@ -338,11 +343,13 @@ func (h *Handler) sendFastHTTPHeader(serviceContext *core.ServiceContext, ctx *f
ctx.Response.Header.Set("Access-Control-Allow-Origin", "*")
}
}
if h.Header != nil {
addResponseHeader(&ctx.Response.Header, h.Header)
}
if header, ok := serviceContext.Items().GetInterface("httpResponseHeaders").(http.Header); ok {
addResponseHeader(&ctx.Response.Header, header)
if !h.DisableHTTPHeader {
if h.Header != nil {
addResponseHeader(&ctx.Response.Header, h.Header)
}
if header, ok := serviceContext.Items().GetInterface("httpResponseHeaders").(http.Header); ok {
addResponseHeader(&ctx.Response.Header, header)
}
}
if code := serviceContext.Items().GetInt("httpStatusCode"); code != 0 {
ctx.SetStatusCode(code)
Expand All @@ -352,7 +359,9 @@ func (h *Handler) sendFastHTTPHeader(serviceContext *core.ServiceContext, ctx *f
func (h *Handler) getFastHTTPServiceContext(ctx *fasthttp.RequestCtx) *core.ServiceContext {
serviceContext := core.NewServiceContext(h.Service)
serviceContext.Items().Set("requestCtx", ctx)
serviceContext.Items().Set("httpRequestHeaders", getRequestHeader(&ctx.Request.Header))
if !h.DisableHTTPHeader {
serviceContext.Items().Set("httpRequestHeaders", getRequestHeader(&ctx.Request.Header))
}
serviceContext.LocalAddr, _ = net.ResolveTCPAddr("tcp", convert.ToUnsafeString(ctx.Host()))
serviceContext.RemoteAddr = ctx.RemoteAddr()
serviceContext.Handler = h
Expand Down
23 changes: 14 additions & 9 deletions rpc/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
| |
| rpc/http/transport.go |
| |
| LastModified: May 5, 2021 |
| LastModified: Mar 7, 2022 |
| Author: Ma Bingyao <[email protected]> |
| |
\*________________________________________________________*/
Expand All @@ -28,8 +28,9 @@ import (
)

type Transport struct {
Header http.Header
HTTPClient http.Client
DisableHTTPHeader bool
Header http.Header
HTTPClient http.Client
}

func (trans *Transport) Transport(ctx context.Context, request []byte) ([]byte, error) {
Expand All @@ -38,11 +39,13 @@ func (trans *Transport) Transport(ctx context.Context, request []byte) ([]byte,
if err != nil {
return nil, err
}
if trans.Header != nil {
addHeader(req.Header, trans.Header)
}
if header, ok := clientContext.Items().GetInterface("httpRequestHeaders").(http.Header); ok {
addHeader(req.Header, header)
if !trans.DisableHTTPHeader {
if trans.Header != nil {
addHeader(req.Header, trans.Header)
}
if header, ok := clientContext.Items().GetInterface("httpRequestHeaders").(http.Header); ok {
addHeader(req.Header, header)
}
}
var resp *http.Response
resp, err = trans.HTTPClient.Do(req)
Expand All @@ -54,7 +57,9 @@ func (trans *Transport) Transport(ctx context.Context, request []byte) ([]byte,
clientContext.Items().Set("httpStatusText", http.StatusText(resp.StatusCode))
switch resp.StatusCode {
case http.StatusOK:
clientContext.Items().Set("httpResponseHeaders", resp.Header)
if !trans.DisableHTTPHeader {
clientContext.Items().Set("httpResponseHeaders", resp.Header)
}
return readAll(resp.Body, resp.ContentLength)
case http.StatusRequestEntityTooLarge:
return nil, core.ErrRequestEntityTooLarge
Expand Down

0 comments on commit 6ceaaab

Please sign in to comment.