Added SetBodyStream* to Request. This allows streaming multi-GB data in request bodies

This commit is contained in:
Aliaksandr Valialkin
2016-02-12 20:52:29 +02:00
parent dde91b5c5b
commit 77a12cbf68
2 changed files with 221 additions and 46 deletions
+175 -46
View File
@@ -28,6 +28,8 @@ type Request struct {
body []byte
w requestBodyWriter
bodyStream io.Reader
uri URI
parsedURI bool
@@ -132,6 +134,25 @@ func (resp *Response) SendFile(path string) error {
return nil
}
// SetBodyStream sets request body stream and, optionally body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// Note that GET and HEAD requests cannot have body.
//
// See also SetBodyStreamWriter.
func (req *Request) SetBodyStream(bodyStream io.Reader, bodySize int) {
req.ResetBody()
req.bodyStream = bodyStream
req.Header.SetContentLength(bodySize)
}
// SetBodyStream sets response body stream and, optionally body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
@@ -144,12 +165,28 @@ func (resp *Response) SendFile(path string) error {
//
// See also SetBodyStreamWriter.
func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) {
resp.body = resp.body[:0]
resp.closeBodyStream()
resp.ResetBody()
resp.bodyStream = bodyStream
resp.Header.SetContentLength(bodySize)
}
// SetBodyStreamWriter registers the given sw for populating request body.
//
// This function may be used in the following cases:
//
// * if request body is too big (more than 10MB).
// * if request body is streamed from slow external sources.
// * if request body must be streamed to the server in chunks
// (aka `http client push`).
//
// Note that GET and HEAD requests cannot have body.
//
/// See also SetBodyStream.
func (req *Request) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
req.SetBodyStream(sr, -1)
}
// SetBodyStreamWriter registers the given sw for populating response body.
//
// This function may be used in the following cases:
@@ -158,6 +195,8 @@ func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) {
// * if response body is streamed from slow external sources.
// * if response body must be streamed to the client in chunks
// (aka `http server push`).
//
// See also SetBodyStream.
func (resp *Response) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
resp.SetBodyStream(sr, -1)
@@ -199,6 +238,15 @@ func (w *requestBodyWriter) Write(p []byte) (int, error) {
// Body returns response body.
func (resp *Response) Body() []byte {
if resp.bodyStream != nil {
var w ByteBuffer
_, err := io.Copy(&w, resp.bodyStream)
resp.closeBodyStream()
if err != nil {
return []byte(err.Error())
}
return w.B
}
return resp.body
}
@@ -246,6 +294,20 @@ func (resp *Response) BodyInflate() ([]byte, error) {
return b, nil
}
// BodyWriteTo writes request body to w.
func (req *Request) BodyWriteTo(w io.Writer) error {
if req.bodyStream != nil {
_, err := io.Copy(w, req.bodyStream)
req.closeBodyStream()
return err
}
if req.onlyMultipartForm() {
return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary)
}
_, err := w.Write(req.body)
return err
}
// BodyWriteTo writes response body to w.
func (resp *Response) BodyWriteTo(w io.Writer) error {
if resp.bodyStream != nil {
@@ -289,6 +351,15 @@ func (resp *Response) ResetBody() {
// Body returns request body.
func (req *Request) Body() []byte {
if req.bodyStream != nil {
var w ByteBuffer
_, err := io.Copy(&w, req.bodyStream)
req.closeBodyStream()
if err != nil {
return []byte(err.Error())
}
return w.B
}
if req.onlyMultipartForm() {
body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
@@ -299,42 +370,38 @@ func (req *Request) Body() []byte {
return req.body
}
// BodyWriteTo writes request body to w.
func (req *Request) BodyWriteTo(w io.Writer) error {
if req.onlyMultipartForm() {
return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary)
}
_, err := w.Write(req.body)
return err
}
// AppendBody appends p to request body.
func (req *Request) AppendBody(p []byte) {
req.RemoveMultipartFormFiles()
req.closeBodyStream()
req.body = append(req.body, p...)
}
// AppendBodyString appends s to request body.
func (req *Request) AppendBodyString(s string) {
req.RemoveMultipartFormFiles()
req.closeBodyStream()
req.body = append(req.body, s...)
}
// SetBody sets request body.
func (req *Request) SetBody(body []byte) {
req.RemoveMultipartFormFiles()
req.closeBodyStream()
req.body = append(req.body[:0], body...)
}
// SetBodyString sets request body.
func (req *Request) SetBodyString(body string) {
req.RemoveMultipartFormFiles()
req.closeBodyStream()
req.body = append(req.body[:0], body...)
}
// ResetBody resets request body.
func (req *Request) ResetBody() {
req.RemoveMultipartFormFiles()
req.closeBodyStream()
req.body = req.body[:0]
}
@@ -518,6 +585,7 @@ func (req *Request) Reset() {
}
func (req *Request) resetSkipHeader() {
req.closeBodyStream()
req.body = req.body[:0]
req.uri.Reset()
req.parsedURI = false
@@ -809,6 +877,10 @@ func (req *Request) Write(w *bufio.Writer) error {
req.Header.SetRequestURIBytes(uri.RequestURI())
}
if req.bodyStream != nil {
return req.writeBodyStream(w)
}
body := req.body
var err error
if req.onlyMultipartForm() {
@@ -952,53 +1024,110 @@ func (resp *Response) deflateBody(level int) error {
//
// Write doesn't flush response to w for performance reasons.
func (resp *Response) Write(w *bufio.Writer) error {
var err error
sendBody := !resp.mustSkipBody()
if resp.bodyStream != nil {
contentLength := resp.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(resp.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
}
}
if contentLength >= 0 {
if err = resp.Header.Write(w); err != nil {
return err
}
if sendBody {
if err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)); err != nil {
return err
}
}
} else {
resp.Header.SetContentLength(-1)
if err = resp.Header.Write(w); err != nil {
return err
}
if sendBody {
if err = writeBodyChunked(w, resp.bodyStream); err != nil {
return err
}
}
}
return resp.closeBodyStream()
return resp.writeBodyStream(w, sendBody)
}
bodyLen := len(resp.body)
if sendBody || bodyLen > 0 {
resp.Header.SetContentLength(bodyLen)
}
if err = resp.Header.Write(w); err != nil {
if err := resp.Header.Write(w); err != nil {
return err
}
if sendBody {
_, err = w.Write(resp.body)
if _, err := w.Write(resp.body); err != nil {
return err
}
}
return nil
}
func (req *Request) writeBodyStream(w *bufio.Writer) error {
var err error
contentLength := req.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(req.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
if contentLength >= 0 {
req.Header.SetContentLength(contentLength)
}
}
}
if contentLength >= 0 {
if err = req.Header.Write(w); err != nil {
return err
}
if err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength)); err != nil {
return err
}
} else {
req.Header.SetContentLength(-1)
if err = req.Header.Write(w); err != nil {
return err
}
if err = writeBodyChunked(w, req.bodyStream); err != nil {
return err
}
}
return req.closeBodyStream()
}
func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) error {
var err error
contentLength := resp.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(resp.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
if contentLength >= 0 {
resp.Header.SetContentLength(contentLength)
}
}
}
if contentLength >= 0 {
if err = resp.Header.Write(w); err != nil {
return err
}
if sendBody {
if err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)); err != nil {
return err
}
}
} else {
resp.Header.SetContentLength(-1)
if err = resp.Header.Write(w); err != nil {
return err
}
if sendBody {
if err = writeBodyChunked(w, resp.bodyStream); err != nil {
return err
}
}
}
return resp.closeBodyStream()
}
func (req *Request) closeBodyStream() error {
if req.bodyStream == nil {
return nil
}
var err error
if bsc, ok := req.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
req.bodyStream = nil
return err
}
@@ -1110,7 +1239,7 @@ func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error {
}
if n != size && err == nil {
err = fmt.Errorf("copied %d bytes from response body stream instead of %d bytes", n, size)
err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size)
}
return err
}
+46
View File
@@ -671,12 +671,28 @@ func TestRequestWriteRequestURINoHost(t *testing.T) {
}
}
func TestSetRequestBodyStreamFixedSize(t *testing.T) {
testSetRequestBodyStream(t, "a", false)
testSetRequestBodyStream(t, string(createFixedBody(4097)), false)
testSetRequestBodyStream(t, string(createFixedBody(100500)), false)
}
func TestSetResponseBodyStreamFixedSize(t *testing.T) {
testSetResponseBodyStream(t, "a", false)
testSetResponseBodyStream(t, string(createFixedBody(4097)), false)
testSetResponseBodyStream(t, string(createFixedBody(100500)), false)
}
func TestSetRequestBodyStreamChunked(t *testing.T) {
testSetRequestBodyStream(t, "", true)
body := "foobar baz aaa bbb ccc"
testSetRequestBodyStream(t, body, true)
body = string(createFixedBody(10001))
testSetRequestBodyStream(t, body, true)
}
func TestSetResponseBodyStreamChunked(t *testing.T) {
testSetResponseBodyStream(t, "", true)
@@ -687,6 +703,36 @@ func TestSetResponseBodyStreamChunked(t *testing.T) {
testSetResponseBodyStream(t, body, true)
}
func testSetRequestBodyStream(t *testing.T, body string, chunked bool) {
var req Request
req.Header.SetHost("foobar.com")
req.Header.SetMethod("POST")
bodySize := len(body)
if chunked {
bodySize = -1
}
req.SetBodyStream(bytes.NewBufferString(body), bodySize)
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := req.Write(bw); err != nil {
t.Fatalf("unexpected error when writing request: %s. body=%q", err, body)
}
if err := bw.Flush(); err != nil {
t.Fatalf("unexpected error when flushing request: %s. body=%q", err, body)
}
var req1 Request
br := bufio.NewReader(&w)
if err := req1.Read(br); err != nil {
t.Fatalf("unexpected error when reading request: %s. body=%q", err, body)
}
if string(req1.Body()) != body {
t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body)
}
}
func testSetResponseBodyStream(t *testing.T, body string, chunked bool) {
var resp Response
bodySize := len(body)