From 3284c3e6711972e6934d122a10bcaec8bf967026 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Fri, 25 Dec 2015 16:11:20 +0200 Subject: [PATCH] Pull request #24: added support for '100 Continue' responses and 'Expect: 100-continue' requests. Kudos to @celer --- http.go | 95 ++++++++++++++++++++++++++++++++++++++++---------- http_test.go | 55 +++++++++++++++++++++++++++-- server.go | 43 ++++++++++++++++++----- server_test.go | 41 ++++++++++++++++++++++ strings.go | 4 +++ 5 files changed, 208 insertions(+), 30 deletions(-) diff --git a/http.go b/http.go index 0fb7fa7..e7526da 100644 --- a/http.go +++ b/http.go @@ -442,6 +442,14 @@ func (resp *Response) resetSkipHeader() { // RemoveMultipartFormFiles or Reset must be called after // reading multipart/form-data request in order to delete temporarily // uploaded files. +// +// If MayContinue returns true, the caller must: +// +// - Either send StatusExpectationFailed response if request headers don't +// satisfy the caller. +// - Or send StatusContinue response before reading request body +// with ContinueReadBody. +// - Or close the connection. func (req *Request) Read(r *bufio.Reader) error { return req.ReadLimitBody(r, 0) } @@ -458,6 +466,14 @@ var errGetOnly = errors.New("non-GET request received") // RemoveMultipartFormFiles or Reset must be called after // reading multipart/form-data request in order to delete temporarily // uploaded files. +// +// If MayContinue returns true, the caller must: +// +// - Either send StatusExpectationFailed response if request headers don't +// satisfy the caller. +// - Or send StatusContinue response before reading request body +// with ContinueReadBody. +// - Or close the connection. func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { return req.readLimitBody(r, maxBodySize, false) } @@ -472,29 +488,64 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool return errGetOnly } - if !req.Header.noBody() { - contentLength := req.Header.ContentLength() - if contentLength > 0 { - // Pre-read multipart form data of known length. - // This way we limit memory usage for large file uploads, since their contents - // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. - boundary := req.Header.MultipartFormBoundary() - if len(boundary) > 0 { - req.multipartForm, err = readMultipartFormBody(r, boundary, maxBodySize, defaultMaxInMemoryFileSize) - if err != nil { - req.Reset() - } - return err - } - } + if req.Header.noBody() { + return nil + } - req.body, err = readBody(r, contentLength, maxBodySize, req.body) - if err != nil { - req.Reset() + if req.MayContinue() { + // 'Expect: 100-continue' header found. Let the caller deciding + // whether to read request body or + // to return StatusExpectationFailed. + return nil + } + + return req.ContinueReadBody(r, maxBodySize) +} + +// MayContinue returns true if the request contains +// 'Expect: 100-continue' header. +// +// The caller must do one of the following actions if MayContinue returns true: +// +// - Either send StatusExpectationFailed response if request headers don't +// satisfy the caller. +// - Or send StatusContinue response before reading request body +// with ContinueReadBody. +// - Or close the connection. +func (req *Request) MayContinue() bool { + return bytes.Equal(req.Header.peek(strExpect), str100Continue) +} + +// ContinueReadBody reads request body if request header contains +// 'Expect: 100-continue'. +// +// The caller must send StatusContinue response before calling this method. +// +// If maxBodySize > 0 and the body size exceeds maxBodySize, +// then ErrBodyTooLarge is returned. +func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int) error { + var err error + contentLength := req.Header.ContentLength() + if contentLength > 0 { + // Pre-read multipart form data of known length. + // This way we limit memory usage for large file uploads, since their contents + // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. + boundary := req.Header.MultipartFormBoundary() + if len(boundary) > 0 { + req.multipartForm, err = readMultipartFormBody(r, boundary, maxBodySize, defaultMaxInMemoryFileSize) + if err != nil { + req.Reset() + } return err } - req.Header.SetContentLength(len(req.body)) } + + req.body, err = readBody(r, contentLength, maxBodySize, req.body) + if err != nil { + req.Reset() + return err + } + req.Header.SetContentLength(len(req.body)) return nil } @@ -513,6 +564,12 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { if err != nil { return err } + if resp.Header.StatusCode() == StatusContinue { + // Read the next response according to http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html . + if err = resp.Header.Read(r); err != nil { + return err + } + } if !isSkipResponseBody(resp.Header.StatusCode()) && !resp.SkipBody { resp.body, err = readBody(r, resp.Header.ContentLength(), maxBodySize, resp.body) diff --git a/http_test.go b/http_test.go index 11284ef..7ad9dbd 100644 --- a/http_test.go +++ b/http_test.go @@ -4,11 +4,58 @@ import ( "bufio" "bytes" "fmt" + "io/ioutil" "mime/multipart" "strings" "testing" ) +func TestRequestContinueReadBody(t *testing.T) { + s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + br := bufio.NewReader(bytes.NewBufferString(s)) + + var r Request + if err := r.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !r.MayContinue() { + t.Fatalf("MayContinue must return true") + } + + if err := r.ContinueReadBody(br, 0); err != nil { + t.Fatalf("error when reading request body: %s", err) + } + body := r.Body() + if string(body) != "abcde" { + t.Fatalf("unexpected body %q. Expecting %q", body, "abcde") + } + + tail, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(tail) != "f4343" { + t.Fatalf("unexpected tail %q. Expecting %q", tail, "f4343") + } +} + +func TestRequestMayContinue(t *testing.T) { + var r Request + if r.MayContinue() { + t.Fatalf("MayContinue on empty request must return false") + } + + r.Header.Set("Expect", "123sdfds") + if r.MayContinue() { + t.Fatalf("MayContinue on invalid Expect header must return false") + } + + r.Header.Set("Expect", "100-continue") + if !r.MayContinue() { + t.Fatalf("MayContinue on 'Expect: 100-continue' header must return true") + } +} + func TestResponseGzipStream(t *testing.T) { var r Response r.SetBodyStreamWriter(func(w *bufio.Writer) { @@ -405,11 +452,15 @@ func TestResponseReadWithoutBody(t *testing.T) { testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false, 204, -1, "aab", "123\r\nss") - testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false, - 100, 3434, "xxx", "aaaa") + testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false, + 123, 3434, "xxx", "aaaa") testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nxxxx", true, 200, 123, "text/xml", "xxxx") + + // '100 Continue' must be skipped. + testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\nfoobar", true, + 329, 894, "qwe", "foobar") } func testResponseReadWithoutBody(t *testing.T, resp *Response, s string, skipBody bool, diff --git a/server.go b/server.go index 1c93aac..4f10fca 100644 --- a/server.go +++ b/server.go @@ -1107,20 +1107,16 @@ func (s *Server) serveConn(c net.Conn) error { if br == nil { br = acquireReader(ctx) } + } else { + br, err = acquireByteReader(&ctx) + } + + if err == nil { err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly) if br.Buffered() == 0 || err != nil { releaseReader(s, br) br = nil } - } else { - br, err = acquireByteReader(&ctx) - if err == nil { - err = ctx.Request.ReadLimitBody(br, s.MaxRequestBodySize) - if br.Buffered() == 0 || err != nil { - releaseReader(s, br) - br = nil - } - } } currentTime = time.Now() @@ -1133,6 +1129,35 @@ func (s *Server) serveConn(c net.Conn) error { break } + // 'Expect: 100-continue' request handling. + // See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details. + if !ctx.Request.Header.noBody() && ctx.Request.MayContinue() { + // Send 'HTTP/1.1 100 Continue' response. + if bw == nil { + bw = acquireWriter(ctx) + } + bw.Write(strResponseContinue) + err = bw.Flush() + releaseWriter(s, bw) + bw = nil + if err != nil { + break + } + + // Read request body. + if br == nil { + br = acquireReader(ctx) + } + err = ctx.Request.ContinueReadBody(br, s.MaxRequestBodySize) + if br.Buffered() == 0 || err != nil { + releaseReader(s, br) + br = nil + } + if err != nil { + break + } + } + ctx.connRequestNum = connRequestNum ctx.connTime = connTime ctx.time = currentTime diff --git a/server_test.go b/server_test.go index 43e6bf2..f9cb54a 100644 --- a/server_test.go +++ b/server_test.go @@ -12,6 +12,47 @@ import ( "time" ) +func TestServerExpect100Continue(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + if !ctx.IsPost() { + t.Fatalf("unexpected method %q. Expecting POST", ctx.Method()) + } + if string(ctx.Path()) != "/foo" { + t.Fatalf("unexpected path %q. Expecting %q", ctx.Path(), "/foo") + } + ct := ctx.Request.Header.ContentType() + if string(ct) != "a/b" { + t.Fatalf("unexpectected content-type: %q. Expecting %q", ct, "a/b") + } + if string(ctx.PostBody()) != "12345" { + t.Fatalf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345") + } + ctx.WriteString("foobar") + }, + } + + rw := &readWriter{} + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar") +} + func TestCompressHandler(t *testing.T) { expectedBody := "foo/bar/baz" h := CompressHandler(func(ctx *RequestCtx) { diff --git a/strings.go b/strings.go index 02132a1..5863f66 100644 --- a/strings.go +++ b/strings.go @@ -19,11 +19,14 @@ var ( strColonSlashSlash = []byte("://") strColonSpace = []byte(": ") + strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n") + strGet = []byte("GET") strHead = []byte("HEAD") strPost = []byte("POST") strPut = []byte("PUT") + strExpect = []byte("Expect") strConnection = []byte("Connection") strContentLength = []byte("Content-Length") strContentType = []byte("Content-Type") @@ -53,6 +56,7 @@ var ( strUpgrade = []byte("Upgrade") strChunked = []byte("chunked") strIdentity = []byte("identity") + str100Continue = []byte("100-continue") strPostArgsContentType = []byte("application/x-www-form-urlencoded") strMultipartFormData = []byte("multipart/form-data") strBoundary = []byte("boundary")