diff --git a/net/wasihttp/adapter.go b/net/wasihttp/adapter.go index d15d690..74ed0f7 100644 --- a/net/wasihttp/adapter.go +++ b/net/wasihttp/adapter.go @@ -108,6 +108,10 @@ func (row *responseOutparamWriter) reconcile() { } func (row *responseOutparamWriter) Close() error { + if row.stream == nil { + return nil + } + row.stream.BlockingFlush() row.stream.ResourceDrop() @@ -138,14 +142,12 @@ func (row *responseOutparamWriter) Close() error { // convert the ResponseOutparam to http.ResponseWriter func NewHttpResponseWriter(out types.ResponseOutparam) *responseOutparamWriter { - row := &responseOutparamWriter{ + return &responseOutparamWriter{ outparam: out, httpHeaders: http.Header{}, wasiHeaders: types.NewFields(), statuscode: http.StatusOK, } - - return row } // convert the IncomingRequest to http.Request @@ -167,7 +169,16 @@ func NewHttpRequest(ir IncomingRequest) (req *http.Request, err error) { body, trailers, err := NewIncomingBodyTrailer(ir) if err != nil { - return nil, fmt.Errorf("failed to consume incoming request %s", err) + switch method { + case http.MethodGet, + http.MethodHead, + http.MethodDelete, + http.MethodConnect, + http.MethodOptions, + http.MethodTrace: + default: + return nil, fmt.Errorf("failed to consume incoming request: %w", err) + } } url := fmt.Sprintf("http://%s%s", authority, pathWithQuery) @@ -188,23 +199,23 @@ func NewHttpRequest(ir IncomingRequest) (req *http.Request, err error) { func methodToString(m types.Method) (string, error) { if m.Connect() { - return "CONNECT", nil + return http.MethodConnect, nil } else if m.Delete() { - return "DELETE", nil + return http.MethodDelete, nil } else if m.Get() { - return "GET", nil + return http.MethodGet, nil } else if m.Head() { - return "HEAD", nil + return http.MethodHead, nil } else if m.Options() { - return "OPTIONS", nil + return http.MethodOptions, nil } else if m.Patch() { - return "PATCH", nil + return http.MethodPatch, nil } else if m.Post() { - return "POST", nil + return http.MethodPost, nil } else if m.Put() { - return "PUT", nil + return http.MethodPut, nil } else if m.Trace() { - return "TRACE", nil + return http.MethodTrace, nil } else if other := m.Other(); other != nil { return *other, fmt.Errorf("unknown http method '%s'", *other) } diff --git a/net/wasihttp/server.go b/net/wasihttp/server.go index 221b070..beac19a 100644 --- a/net/wasihttp/server.go +++ b/net/wasihttp/server.go @@ -5,6 +5,7 @@ import ( "net/http" "os" + "github.com/bytecodealliance/wasm-tools-go/cm" incominghandler "go.wasmcloud.dev/component/gen/wasi/http/incoming-handler" "go.wasmcloud.dev/component/gen/wasi/http/types" ) @@ -31,10 +32,14 @@ func HandleFunc(h http.HandlerFunc) { func wasiHandle(request types.IncomingRequest, responseOut types.ResponseOutparam) { httpReq, err := NewHttpRequest(request) if err != nil { - fmt.Fprintf(os.Stderr, "failed to convert wasi/http/types.IncomingRequest to http.Request: %s\n", err) + types.ResponseOutparamSet(responseOut, cm.Err[cm.Result[types.ErrorCodeShape, types.OutgoingResponse, types.ErrorCode]]( + types.ErrorCodeInternalError(cm.Some(err.Error()))), + ) return } - defer httpReq.Body.Close() + if httpReq.Body != nil { + defer httpReq.Body.Close() + } httpRes := NewHttpResponseWriter(responseOut) defer httpRes.Close() diff --git a/net/wasihttp/streams.go b/net/wasihttp/streams.go index adf09ed..6e26832 100644 --- a/net/wasihttp/streams.go +++ b/net/wasihttp/streams.go @@ -1,6 +1,7 @@ package wasihttp import ( + "errors" "fmt" "io" "net/http" @@ -72,13 +73,12 @@ func (r *inputStreamReader) parseTrailers() { func (r *inputStreamReader) Read(p []byte) (n int, err error) { readResult := r.stream.BlockingRead(uint64(len(p))) - if readResult.IsErr() { - readErr := readResult.Err() - if readErr.Closed() { + if err := readResult.Err(); err != nil { + if err.Closed() { r.trailerOnce.Do(r.parseTrailers) return 0, io.EOF } - return 0, fmt.Errorf("failed to read from InputStream %s", readErr.LastOperationFailed().ToDebugString()) + return 0, fmt.Errorf("failed to read from InputStream %s", err.LastOperationFailed().ToDebugString()) } readList := *readResult.OK() @@ -90,12 +90,12 @@ func NewIncomingBodyTrailer(consumer BodyConsumer) (io.ReadCloser, http.Header, trailers := http.Header{} consumeResult := consumer.Consume() if consumeResult.IsErr() { - return nil, nil, fmt.Errorf("failed to consume incoming request %s", *consumeResult.Err()) + return nil, nil, errors.New("failed to consume incoming request") } body := *consumeResult.OK() streamResult := body.Stream() if streamResult.IsErr() { - return nil, nil, fmt.Errorf("failed to consume incoming requests's stream %s", streamResult.Err()) + return nil, nil, errors.New("failed to consume incoming request body stream") } return &inputStreamReader{ consumer: consumer, @@ -113,7 +113,7 @@ type outputStreamReader struct { func NewOutgoingBody(body types.OutgoingBody) (io.WriteCloser, error) { stream := body.Write() if stream.IsErr() { - return nil, fmt.Errorf("failed to acquire resource handle to request body: %s", stream.Err()) + return nil, errors.New("failed to acquire resource handle to request body") } return &outputStreamReader{ body: body, @@ -130,12 +130,11 @@ func (r *outputStreamReader) Close() error { func (r *outputStreamReader) Write(p []byte) (n int, err error) { contents := cm.ToList(p) writeResult := r.stream.BlockingWriteAndFlush(contents) - if writeResult.IsErr() { - if writeResult.Err().Closed() { + if err := writeResult.Err(); err != nil { + if err.Closed() { return 0, io.EOF } - - return 0, fmt.Errorf("failed to write to response body's stream: %s", writeResult.Err().LastOperationFailed().ToDebugString()) + return 0, fmt.Errorf("failed to write to response body's stream: %s", err.LastOperationFailed().ToDebugString()) } return len(p), nil }