Pull request #24: added support for '100 Continue' responses and 'Expect: 100-continue' requests. Kudos to @celer

This commit is contained in:
Aliaksandr Valialkin
2015-12-25 16:11:20 +02:00
parent 9074415b7a
commit 3284c3e671
5 changed files with 208 additions and 30 deletions
+76 -19
View File
@@ -442,6 +442,14 @@ func (resp *Response) resetSkipHeader() {
// RemoveMultipartFormFiles or Reset must be called after
// reading multipart/form-data request in order to delete temporarily
// uploaded files.
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
func (req *Request) Read(r *bufio.Reader) error {
return req.ReadLimitBody(r, 0)
}
@@ -458,6 +466,14 @@ var errGetOnly = errors.New("non-GET request received")
// RemoveMultipartFormFiles or Reset must be called after
// reading multipart/form-data request in order to delete temporarily
// uploaded files.
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
return req.readLimitBody(r, maxBodySize, false)
}
@@ -472,29 +488,64 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
return errGetOnly
}
if !req.Header.noBody() {
contentLength := req.Header.ContentLength()
if contentLength > 0 {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
boundary := req.Header.MultipartFormBoundary()
if len(boundary) > 0 {
req.multipartForm, err = readMultipartFormBody(r, boundary, maxBodySize, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
if req.Header.noBody() {
return nil
}
req.body, err = readBody(r, contentLength, maxBodySize, req.body)
if err != nil {
req.Reset()
if req.MayContinue() {
// 'Expect: 100-continue' header found. Let the caller deciding
// whether to read request body or
// to return StatusExpectationFailed.
return nil
}
return req.ContinueReadBody(r, maxBodySize)
}
// MayContinue returns true if the request contains
// 'Expect: 100-continue' header.
//
// The caller must do one of the following actions if MayContinue returns true:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
func (req *Request) MayContinue() bool {
return bytes.Equal(req.Header.peek(strExpect), str100Continue)
}
// ContinueReadBody reads request body if request header contains
// 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int) error {
var err error
contentLength := req.Header.ContentLength()
if contentLength > 0 {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
boundary := req.Header.MultipartFormBoundary()
if len(boundary) > 0 {
req.multipartForm, err = readMultipartFormBody(r, boundary, maxBodySize, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
req.Header.SetContentLength(len(req.body))
}
req.body, err = readBody(r, contentLength, maxBodySize, req.body)
if err != nil {
req.Reset()
return err
}
req.Header.SetContentLength(len(req.body))
return nil
}
@@ -513,6 +564,12 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
if err != nil {
return err
}
if resp.Header.StatusCode() == StatusContinue {
// Read the next response according to http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html .
if err = resp.Header.Read(r); err != nil {
return err
}
}
if !isSkipResponseBody(resp.Header.StatusCode()) && !resp.SkipBody {
resp.body, err = readBody(r, resp.Header.ContentLength(), maxBodySize, resp.body)
+53 -2
View File
@@ -4,11 +4,58 @@ import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"mime/multipart"
"strings"
"testing"
)
func TestRequestContinueReadBody(t *testing.T) {
s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
br := bufio.NewReader(bytes.NewBufferString(s))
var r Request
if err := r.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !r.MayContinue() {
t.Fatalf("MayContinue must return true")
}
if err := r.ContinueReadBody(br, 0); err != nil {
t.Fatalf("error when reading request body: %s", err)
}
body := r.Body()
if string(body) != "abcde" {
t.Fatalf("unexpected body %q. Expecting %q", body, "abcde")
}
tail, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if string(tail) != "f4343" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "f4343")
}
}
func TestRequestMayContinue(t *testing.T) {
var r Request
if r.MayContinue() {
t.Fatalf("MayContinue on empty request must return false")
}
r.Header.Set("Expect", "123sdfds")
if r.MayContinue() {
t.Fatalf("MayContinue on invalid Expect header must return false")
}
r.Header.Set("Expect", "100-continue")
if !r.MayContinue() {
t.Fatalf("MayContinue on 'Expect: 100-continue' header must return true")
}
}
func TestResponseGzipStream(t *testing.T) {
var r Response
r.SetBodyStreamWriter(func(w *bufio.Writer) {
@@ -405,11 +452,15 @@ func TestResponseReadWithoutBody(t *testing.T) {
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false,
204, -1, "aab", "123\r\nss")
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false,
100, 3434, "xxx", "aaaa")
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false,
123, 3434, "xxx", "aaaa")
testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nxxxx", true,
200, 123, "text/xml", "xxxx")
// '100 Continue' must be skipped.
testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\nfoobar", true,
329, 894, "qwe", "foobar")
}
func testResponseReadWithoutBody(t *testing.T, resp *Response, s string, skipBody bool,
+34 -9
View File
@@ -1107,20 +1107,16 @@ func (s *Server) serveConn(c net.Conn) error {
if br == nil {
br = acquireReader(ctx)
}
} else {
br, err = acquireByteReader(&ctx)
}
if err == nil {
err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly)
if br.Buffered() == 0 || err != nil {
releaseReader(s, br)
br = nil
}
} else {
br, err = acquireByteReader(&ctx)
if err == nil {
err = ctx.Request.ReadLimitBody(br, s.MaxRequestBodySize)
if br.Buffered() == 0 || err != nil {
releaseReader(s, br)
br = nil
}
}
}
currentTime = time.Now()
@@ -1133,6 +1129,35 @@ func (s *Server) serveConn(c net.Conn) error {
break
}
// 'Expect: 100-continue' request handling.
// See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details.
if !ctx.Request.Header.noBody() && ctx.Request.MayContinue() {
// Send 'HTTP/1.1 100 Continue' response.
if bw == nil {
bw = acquireWriter(ctx)
}
bw.Write(strResponseContinue)
err = bw.Flush()
releaseWriter(s, bw)
bw = nil
if err != nil {
break
}
// Read request body.
if br == nil {
br = acquireReader(ctx)
}
err = ctx.Request.ContinueReadBody(br, s.MaxRequestBodySize)
if br.Buffered() == 0 || err != nil {
releaseReader(s, br)
br = nil
}
if err != nil {
break
}
}
ctx.connRequestNum = connRequestNum
ctx.connTime = connTime
ctx.time = currentTime
+41
View File
@@ -12,6 +12,47 @@ import (
"time"
)
func TestServerExpect100Continue(t *testing.T) {
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsPost() {
t.Fatalf("unexpected method %q. Expecting POST", ctx.Method())
}
if string(ctx.Path()) != "/foo" {
t.Fatalf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
}
ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Fatalf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
}
if string(ctx.PostBody()) != "12345" {
t.Fatalf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
}
ctx.WriteString("foobar")
},
}
rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
select {
case err := <-ch:
if err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timeout")
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar")
}
func TestCompressHandler(t *testing.T) {
expectedBody := "foo/bar/baz"
h := CompressHandler(func(ctx *RequestCtx) {
+4
View File
@@ -19,11 +19,14 @@ var (
strColonSlashSlash = []byte("://")
strColonSpace = []byte(": ")
strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n")
strGet = []byte("GET")
strHead = []byte("HEAD")
strPost = []byte("POST")
strPut = []byte("PUT")
strExpect = []byte("Expect")
strConnection = []byte("Connection")
strContentLength = []byte("Content-Length")
strContentType = []byte("Content-Type")
@@ -53,6 +56,7 @@ var (
strUpgrade = []byte("Upgrade")
strChunked = []byte("chunked")
strIdentity = []byte("identity")
str100Continue = []byte("100-continue")
strPostArgsContentType = []byte("application/x-www-form-urlencoded")
strMultipartFormData = []byte("multipart/form-data")
strBoundary = []byte("boundary")