diff --git a/http.go b/http.go index ca0613d..f5b6580 100644 --- a/http.go +++ b/http.go @@ -28,6 +28,8 @@ type Request struct { body []byte w requestBodyWriter + bodyStream io.Reader + uri URI parsedURI bool @@ -132,6 +134,25 @@ func (resp *Response) SendFile(path string) error { return nil } +// SetBodyStream sets request body stream and, optionally body size. +// +// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes +// before returning io.EOF. +// +// If bodySize < 0, then bodyStream is read until io.EOF. +// +// bodyStream.Close() is called after finishing reading all body data +// if it implements io.Closer. +// +// Note that GET and HEAD requests cannot have body. +// +// See also SetBodyStreamWriter. +func (req *Request) SetBodyStream(bodyStream io.Reader, bodySize int) { + req.ResetBody() + req.bodyStream = bodyStream + req.Header.SetContentLength(bodySize) +} + // SetBodyStream sets response body stream and, optionally body size. // // If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes @@ -144,12 +165,28 @@ func (resp *Response) SendFile(path string) error { // // See also SetBodyStreamWriter. func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) { - resp.body = resp.body[:0] - resp.closeBodyStream() + resp.ResetBody() resp.bodyStream = bodyStream resp.Header.SetContentLength(bodySize) } +// SetBodyStreamWriter registers the given sw for populating request body. +// +// This function may be used in the following cases: +// +// * if request body is too big (more than 10MB). +// * if request body is streamed from slow external sources. +// * if request body must be streamed to the server in chunks +// (aka `http client push`). +// +// Note that GET and HEAD requests cannot have body. +// +/// See also SetBodyStream. +func (req *Request) SetBodyStreamWriter(sw StreamWriter) { + sr := NewStreamReader(sw) + req.SetBodyStream(sr, -1) +} + // SetBodyStreamWriter registers the given sw for populating response body. // // This function may be used in the following cases: @@ -158,6 +195,8 @@ func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) { // * if response body is streamed from slow external sources. // * if response body must be streamed to the client in chunks // (aka `http server push`). +// +// See also SetBodyStream. func (resp *Response) SetBodyStreamWriter(sw StreamWriter) { sr := NewStreamReader(sw) resp.SetBodyStream(sr, -1) @@ -199,6 +238,15 @@ func (w *requestBodyWriter) Write(p []byte) (int, error) { // Body returns response body. func (resp *Response) Body() []byte { + if resp.bodyStream != nil { + var w ByteBuffer + _, err := io.Copy(&w, resp.bodyStream) + resp.closeBodyStream() + if err != nil { + return []byte(err.Error()) + } + return w.B + } return resp.body } @@ -246,6 +294,20 @@ func (resp *Response) BodyInflate() ([]byte, error) { return b, nil } +// BodyWriteTo writes request body to w. +func (req *Request) BodyWriteTo(w io.Writer) error { + if req.bodyStream != nil { + _, err := io.Copy(w, req.bodyStream) + req.closeBodyStream() + return err + } + if req.onlyMultipartForm() { + return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary) + } + _, err := w.Write(req.body) + return err +} + // BodyWriteTo writes response body to w. func (resp *Response) BodyWriteTo(w io.Writer) error { if resp.bodyStream != nil { @@ -289,6 +351,15 @@ func (resp *Response) ResetBody() { // Body returns request body. func (req *Request) Body() []byte { + if req.bodyStream != nil { + var w ByteBuffer + _, err := io.Copy(&w, req.bodyStream) + req.closeBodyStream() + if err != nil { + return []byte(err.Error()) + } + return w.B + } if req.onlyMultipartForm() { body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary) if err != nil { @@ -299,42 +370,38 @@ func (req *Request) Body() []byte { return req.body } -// BodyWriteTo writes request body to w. -func (req *Request) BodyWriteTo(w io.Writer) error { - if req.onlyMultipartForm() { - return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary) - } - _, err := w.Write(req.body) - return err -} - // AppendBody appends p to request body. func (req *Request) AppendBody(p []byte) { req.RemoveMultipartFormFiles() + req.closeBodyStream() req.body = append(req.body, p...) } // AppendBodyString appends s to request body. func (req *Request) AppendBodyString(s string) { req.RemoveMultipartFormFiles() + req.closeBodyStream() req.body = append(req.body, s...) } // SetBody sets request body. func (req *Request) SetBody(body []byte) { req.RemoveMultipartFormFiles() + req.closeBodyStream() req.body = append(req.body[:0], body...) } // SetBodyString sets request body. func (req *Request) SetBodyString(body string) { req.RemoveMultipartFormFiles() + req.closeBodyStream() req.body = append(req.body[:0], body...) } // ResetBody resets request body. func (req *Request) ResetBody() { req.RemoveMultipartFormFiles() + req.closeBodyStream() req.body = req.body[:0] } @@ -518,6 +585,7 @@ func (req *Request) Reset() { } func (req *Request) resetSkipHeader() { + req.closeBodyStream() req.body = req.body[:0] req.uri.Reset() req.parsedURI = false @@ -809,6 +877,10 @@ func (req *Request) Write(w *bufio.Writer) error { req.Header.SetRequestURIBytes(uri.RequestURI()) } + if req.bodyStream != nil { + return req.writeBodyStream(w) + } + body := req.body var err error if req.onlyMultipartForm() { @@ -952,53 +1024,110 @@ func (resp *Response) deflateBody(level int) error { // // Write doesn't flush response to w for performance reasons. func (resp *Response) Write(w *bufio.Writer) error { - var err error sendBody := !resp.mustSkipBody() if resp.bodyStream != nil { - contentLength := resp.Header.ContentLength() - if contentLength < 0 { - lrSize := limitedReaderSize(resp.bodyStream) - if lrSize >= 0 { - contentLength = int(lrSize) - if int64(contentLength) != lrSize { - contentLength = -1 - } - } - } - if contentLength >= 0 { - if err = resp.Header.Write(w); err != nil { - return err - } - if sendBody { - if err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)); err != nil { - return err - } - } - } else { - resp.Header.SetContentLength(-1) - if err = resp.Header.Write(w); err != nil { - return err - } - if sendBody { - if err = writeBodyChunked(w, resp.bodyStream); err != nil { - return err - } - } - } - return resp.closeBodyStream() + return resp.writeBodyStream(w, sendBody) } bodyLen := len(resp.body) if sendBody || bodyLen > 0 { resp.Header.SetContentLength(bodyLen) } - if err = resp.Header.Write(w); err != nil { + if err := resp.Header.Write(w); err != nil { return err } if sendBody { - _, err = w.Write(resp.body) + if _, err := w.Write(resp.body); err != nil { + return err + } } + return nil +} + +func (req *Request) writeBodyStream(w *bufio.Writer) error { + var err error + + contentLength := req.Header.ContentLength() + if contentLength < 0 { + lrSize := limitedReaderSize(req.bodyStream) + if lrSize >= 0 { + contentLength = int(lrSize) + if int64(contentLength) != lrSize { + contentLength = -1 + } + if contentLength >= 0 { + req.Header.SetContentLength(contentLength) + } + } + } + if contentLength >= 0 { + if err = req.Header.Write(w); err != nil { + return err + } + if err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength)); err != nil { + return err + } + } else { + req.Header.SetContentLength(-1) + if err = req.Header.Write(w); err != nil { + return err + } + if err = writeBodyChunked(w, req.bodyStream); err != nil { + return err + } + } + return req.closeBodyStream() +} + +func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) error { + var err error + + contentLength := resp.Header.ContentLength() + if contentLength < 0 { + lrSize := limitedReaderSize(resp.bodyStream) + if lrSize >= 0 { + contentLength = int(lrSize) + if int64(contentLength) != lrSize { + contentLength = -1 + } + if contentLength >= 0 { + resp.Header.SetContentLength(contentLength) + } + } + } + if contentLength >= 0 { + if err = resp.Header.Write(w); err != nil { + return err + } + if sendBody { + if err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)); err != nil { + return err + } + } + } else { + resp.Header.SetContentLength(-1) + if err = resp.Header.Write(w); err != nil { + return err + } + if sendBody { + if err = writeBodyChunked(w, resp.bodyStream); err != nil { + return err + } + } + } + return resp.closeBodyStream() +} + +func (req *Request) closeBodyStream() error { + if req.bodyStream == nil { + return nil + } + var err error + if bsc, ok := req.bodyStream.(io.Closer); ok { + err = bsc.Close() + } + req.bodyStream = nil return err } @@ -1110,7 +1239,7 @@ func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error { } if n != size && err == nil { - err = fmt.Errorf("copied %d bytes from response body stream instead of %d bytes", n, size) + err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size) } return err } diff --git a/http_test.go b/http_test.go index 4a4c556..c1d0c6a 100644 --- a/http_test.go +++ b/http_test.go @@ -671,12 +671,28 @@ func TestRequestWriteRequestURINoHost(t *testing.T) { } } +func TestSetRequestBodyStreamFixedSize(t *testing.T) { + testSetRequestBodyStream(t, "a", false) + testSetRequestBodyStream(t, string(createFixedBody(4097)), false) + testSetRequestBodyStream(t, string(createFixedBody(100500)), false) +} + func TestSetResponseBodyStreamFixedSize(t *testing.T) { testSetResponseBodyStream(t, "a", false) testSetResponseBodyStream(t, string(createFixedBody(4097)), false) testSetResponseBodyStream(t, string(createFixedBody(100500)), false) } +func TestSetRequestBodyStreamChunked(t *testing.T) { + testSetRequestBodyStream(t, "", true) + + body := "foobar baz aaa bbb ccc" + testSetRequestBodyStream(t, body, true) + + body = string(createFixedBody(10001)) + testSetRequestBodyStream(t, body, true) +} + func TestSetResponseBodyStreamChunked(t *testing.T) { testSetResponseBodyStream(t, "", true) @@ -687,6 +703,36 @@ func TestSetResponseBodyStreamChunked(t *testing.T) { testSetResponseBodyStream(t, body, true) } +func testSetRequestBodyStream(t *testing.T, body string, chunked bool) { + var req Request + req.Header.SetHost("foobar.com") + req.Header.SetMethod("POST") + + bodySize := len(body) + if chunked { + bodySize = -1 + } + req.SetBodyStream(bytes.NewBufferString(body), bodySize) + + var w bytes.Buffer + bw := bufio.NewWriter(&w) + if err := req.Write(bw); err != nil { + t.Fatalf("unexpected error when writing request: %s. body=%q", err, body) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error when flushing request: %s. body=%q", err, body) + } + + var req1 Request + br := bufio.NewReader(&w) + if err := req1.Read(br); err != nil { + t.Fatalf("unexpected error when reading request: %s. body=%q", err, body) + } + if string(req1.Body()) != body { + t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body) + } +} + func testSetResponseBodyStream(t *testing.T, body string, chunked bool) { var resp Response bodySize := len(body)