Streaming fixes (#970)

- Allow DisablePreParseMultipartForm in combination with
StreamRequestBody.
- Support streaming into MultipartForm instead of reading the whole body
  first.
- Support calling ctx.PostBody() when streaming is enabled.
This commit is contained in:
Erik Dubbelboer
2021-02-16 21:53:53 +01:00
committed by GitHub
parent 1b61ca2e36
commit 3cd0862fbb
4 changed files with 265 additions and 161 deletions
+65 -77
View File
@@ -3,6 +3,7 @@ package fasthttp
import (
"bufio"
"bytes"
"compress/gzip"
"encoding/base64"
"errors"
"fmt"
@@ -345,6 +346,15 @@ func (req *Request) bodyBytes() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
}
if req.bodyStream != nil {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
}
if req.body == nil {
return nil
}
@@ -630,14 +640,6 @@ func (req *Request) SwapBody(body []byte) []byte {
func (req *Request) Body() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
} else if req.bodyStream != nil {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
} else if req.onlyMultipartForm() {
body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
@@ -814,24 +816,43 @@ func (req *Request) MultipartForm() (*multipart.Form, error) {
return nil, ErrNoMultipartForm
}
var err error
ce := req.Header.peek(strContentEncoding)
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
var err error
if body, err = AppendGunzipBytes(nil, body); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
if req.bodyStream != nil {
bodyStream := req.bodyStream
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if bodyStream, err = gzip.NewReader(bodyStream); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
mr := multipart.NewReader(bodyStream, req.multipartFormBoundary)
req.multipartForm, err = mr.ReadForm(8 * 1024)
if err != nil {
return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
}
} else {
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if body, err = AppendGunzipBytes(nil, body); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %s", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
if err != nil {
return nil, err
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
f, err := readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
if err != nil {
return nil, err
}
req.multipartForm = f
return f, nil
return req.multipartForm, nil
}
func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) {
@@ -1022,6 +1043,9 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
}
func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error {
// Do not reset the request here - the caller must reset it before
// calling this method.
if getOnly && !req.Header.IsGet() {
return ErrGetOnly
}
@@ -1033,39 +1057,7 @@ func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly boo
return nil
}
var err error
contentLength := req.Header.realContentLength()
if contentLength > 0 {
if preParseMultipartForm {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary())
if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 {
req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
}
if contentLength == -2 {
// identity body has no sense for http requests, since
// the end of body is determined by connection close.
// So just ignore request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
req.Header.SetContentLength(0)
return nil
}
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
return nil
return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm)
}
// MayContinue returns true if the request contains
@@ -1170,21 +1162,15 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B)
bodyBufLen := maxBodySize
if contentLength < maxBodySize {
bodyBufLen = cap(bodyBuf.B)
}
if err != nil {
if err == ErrBodyTooLarge {
req.Header.SetContentLength(contentLength)
req.body = bodyBuf
req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
return nil
}
if err == errChunkedStream {
req.body = bodyBuf
req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, -1)
return nil
}
@@ -1193,7 +1179,6 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
}
req.body = bodyBuf
req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
req.Header.SetContentLength(len(bodyBuf.B))
return nil
@@ -1936,24 +1921,27 @@ func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (
var errChunkedStream = errors.New("chunked stream")
func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) {
dst = dst[:0]
switch {
case contentLength >= 0 && maxBodySize >= contentLength:
readN := maxBodySize
if contentLength > 8*1024 {
readN = 8 * 1024
}
b, err = appendBodyFixedSize(r, dst, readN)
case contentLength == -1:
if contentLength == -1 {
// handled in requestStream.Read()
err = errChunkedStream
default:
readN := maxBodySize
if contentLength > 8*1024 {
readN = 8 * 1024
}
return b, errChunkedStream
}
dst = dst[:0]
readN := maxBodySize
if readN > contentLength {
readN = contentLength
}
if readN > 8*1024 {
readN = 8 * 1024
}
if contentLength >= 0 && maxBodySize >= contentLength {
b, err = appendBodyFixedSize(r, dst, readN)
} else {
b, err = readBodyIdentity(r, readN, dst)
}
if err != nil {
return b, err
}
+100 -83
View File
@@ -1073,7 +1073,16 @@ func TestServerServeTLSEmbed(t *testing.T) {
func TestServerMultipartFormDataRequest(t *testing.T) {
t.Parallel()
reqS := `POST /upload HTTP/1.1
for _, test := range []struct {
StreamRequestBody bool
DisablePreParseMultipartForm bool
}{
{false, false},
{false, true},
{true, false},
{true, true},
} {
reqS := `POST /upload HTTP/1.1
Host: qwerty.com
Content-Length: 521
Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg
@@ -1100,91 +1109,94 @@ Connection: close
`
ln := fasthttputil.NewInmemoryListener()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
switch string(ctx.Path()) {
case "/upload":
f, err := ctx.MultipartForm()
if err != nil {
t.Errorf("unexpected error: %s", err)
s := &Server{
StreamRequestBody: test.StreamRequestBody,
DisablePreParseMultipartForm: test.DisablePreParseMultipartForm,
Handler: func(ctx *RequestCtx) {
switch string(ctx.Path()) {
case "/upload":
f, err := ctx.MultipartForm()
if err != nil {
t.Errorf("unexpected error: %s", err)
}
if len(f.Value) != 1 {
t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
}
if len(f.File) != 1 {
t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1)
}
fv := ctx.FormValue("f1")
if string(fv) != "value1" {
t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1")
}
ctx.Redirect("/", StatusSeeOther)
default:
ctx.WriteString("non-upload") //nolint:errcheck
}
if len(f.Value) != 1 {
t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
}
if len(f.File) != 1 {
t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1)
}
fv := ctx.FormValue("f1")
if string(fv) != "value1" {
t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1")
}
ctx.Redirect("/", StatusSeeOther)
default:
ctx.WriteString("non-upload") //nolint:errcheck
},
}
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
},
}
close(ch)
}()
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = conn.Write([]byte(reqS)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var resp Response
br := bufio.NewReader(conn)
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusSeeOther {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
}
loc := resp.Header.Peek(HeaderLocation)
if string(loc) != "http://qwerty.com/" {
t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/")
if _, err = conn.Write([]byte(reqS)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := resp.Read(br); err != nil {
t.Errorf("error when reading the second response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "non-upload" {
t.Errorf("unexpected body %q. Expecting %q", body, "non-upload")
}
close(respCh)
}()
var resp Response
br := bufio.NewReader(conn)
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusSeeOther {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
}
loc := resp.Header.Peek(HeaderLocation)
if string(loc) != "http://qwerty.com/" {
t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/")
}
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := resp.Read(br); err != nil {
t.Errorf("error when reading the second response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := resp.Body()
if string(body) != "non-upload" {
t.Errorf("unexpected body %q. Expecting %q", body, "non-upload")
}
close(respCh)
}()
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
}
@@ -3413,8 +3425,8 @@ func TestMaxBodySizePerRequest(t *testing.T) {
func TestStreamRequestBody(t *testing.T) {
t.Parallel()
part1 := strings.Repeat("1", 1<<10)
part2 := strings.Repeat("2", 1<<20-1<<10)
part1 := strings.Repeat("1", 1<<15)
part2 := strings.Repeat("2", 1<<16)
contentLength := len(part1) + len(part2)
next := make(chan struct{})
@@ -3424,15 +3436,17 @@ func TestStreamRequestBody(t *testing.T) {
close(next)
checkReader(t, ctx.RequestBodyStream(), part2)
},
DisableKeepalive: true,
StreamRequestBody: true,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
//write headers and part1 body
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1))); err != nil {
t.Error(err)
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength))); err != nil {
t.Fatal(err)
}
if _, err := cc.Write([]byte(part1)); err != nil {
t.Fatal(err)
}
ch := make(chan error)
@@ -3447,12 +3461,15 @@ func TestStreamRequestBody(t *testing.T) {
}
if _, err := cc.Write([]byte(part2)); err != nil {
t.Error(err)
t.Fatal(err)
}
if err := sc.Close(); err != nil {
t.Fatal(err)
}
select {
case err := <-ch:
if err != nil {
if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(500 * time.Millisecond):
+10 -1
View File
@@ -45,7 +45,12 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
var n int
var err error
if int(rs.prefetchedBytes.Size()) > rs.totalBytesRead {
prefetchedSize := int(rs.prefetchedBytes.Size())
if prefetchedSize > rs.totalBytesRead {
left := prefetchedSize - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
if n == rs.contentLength {
@@ -53,6 +58,10 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
return n, err
} else {
left := rs.contentLength - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
n, err = rs.reader.Read(p)
rs.totalBytesRead += n
if err != nil {
+90
View File
@@ -6,10 +6,100 @@ import (
"io/ioutil"
"sync"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
func TestStreamingPipeline(t *testing.T) {
t.Parallel()
reqS := `POST /one HTTP/1.1
Host: example.com
Content-Length: 10
aaaaaaaaaa
POST /two HTTP/1.1
Host: example.com
Content-Length: 10
aaaaaaaaaa`
ln := fasthttputil.NewInmemoryListener()
s := &Server{
StreamRequestBody: true,
Handler: func(ctx *RequestCtx) {
body := ""
expected := "aaaaaaaaaa"
if string(ctx.Path()) == "/one" {
body = string(ctx.PostBody())
} else {
all, err := ioutil.ReadAll(ctx.RequestBodyStream())
if err != nil {
t.Error(err)
}
body = string(all)
}
if body != expected {
t.Errorf("expected %q got %q", expected, body)
}
},
}
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = conn.Write([]byte(reqS)); err != nil {
t.Fatalf("unexpected error: %s", err)
}
var resp Response
br := bufio.NewReader(conn)
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err != nil {
t.Errorf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
close(respCh)
}()
select {
case <-respCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
if err := ln.Close(); err != nil {
t.Fatalf("error when closing listener: %s", err)
}
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout when waiting for the server to stop")
}
}
func TestRequestStream(t *testing.T) {
body := createFixedBody(3)
chunkedBody := createChunkedBody(body)