diff --git a/server.go b/server.go index 29ce909..80ffd39 100644 --- a/server.go +++ b/server.go @@ -172,6 +172,17 @@ type Server struct { // non zero RequestConfig field values will overwrite the default configs HeaderReceived func(header *RequestHeader) RequestConfig + // ContinueHandler is called after receiving the Expect 100 Continue Header + // + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1 + // Using ContinueHandler a server can make decisioning on whether or not + // to read a potentially large request body based on the headers + // + // The default is to automatically read request bodies of Expect 100 Continue requests + // like they are normal requests + ContinueHandler func(header *RequestHeader) bool + // Server name for sending in response headers. // // Default server name is used if left blank. @@ -1935,7 +1946,8 @@ func (s *Server) serveConn(c net.Conn) (err error) { connectionClose bool isHTTP11 bool - reqReset bool + reqReset bool + continueReadingRequest bool = true ) for { connRequestNum++ @@ -2007,6 +2019,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { //read body err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) } + if err == nil { // If we read any bytes off the wire, we're active. s.setState(c, StateActive) @@ -2041,37 +2054,53 @@ func (s *Server) serveConn(c net.Conn) (err error) { } // 'Expect: 100-continue' request handling. - // See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details. + // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details. if ctx.Request.MayContinue() { - // Send 'HTTP/1.1 100 Continue' response. - if bw == nil { - bw = acquireWriter(ctx) - } - _, err = bw.Write(strResponseContinue) - if err != nil { - break - } - err = bw.Flush() - if err != nil { - break - } - if s.ReduceMemoryUsage { - releaseWriter(s, bw) - bw = nil + + // Allow the ability to deny reading the incoming request body + if s.ContinueHandler != nil { + if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest { + if br != nil { + br.Reset(ctx.c) + } + + ctx.SetStatusCode(StatusExpectationFailed) + } } - // Read request body. - if br == nil { - br = acquireReader(ctx) - } - err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm) - if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil { - releaseReader(s, br) - br = nil - } - if err != nil { - bw = s.writeErrorResponse(bw, ctx, serverName, err) - break + if continueReadingRequest { + if bw == nil { + bw = acquireWriter(ctx) + } + + // Send 'HTTP/1.1 100 Continue' response. + _, err = bw.Write(strResponseContinue) + if err != nil { + break + } + err = bw.Flush() + if err != nil { + break + } + if s.ReduceMemoryUsage { + releaseWriter(s, bw) + bw = nil + } + + // Read request body. + if br == nil { + br = acquireReader(ctx) + } + + err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm) + if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil { + releaseReader(s, br) + br = nil + } + if err != nil { + bw = s.writeErrorResponse(bw, ctx, serverName, err) + break + } } } @@ -2084,7 +2113,11 @@ func (s *Server) serveConn(c net.Conn) (err error) { ctx.connID = connID ctx.connRequestNum = connRequestNum ctx.time = time.Now() - s.Handler(ctx) + + // If a client denies a request the handler should not be called + if continueReadingRequest { + s.Handler(ctx) + } timeoutResponse = ctx.timeoutResponse if timeoutResponse != nil { diff --git a/server_test.go b/server_test.go index a7ee71c..4252a6a 100644 --- a/server_test.go +++ b/server_test.go @@ -1654,6 +1654,96 @@ func TestServerExpect100Continue(t *testing.T) { } } +func TestServerContinueHandler(t *testing.T) { + t.Parallel() + + acceptContentLength := 5 + s := &Server{ + ContinueHandler: func(headers *RequestHeader) bool { + if !headers.IsPost() { + t.Errorf("unexpected method %q. Expecting POST", headers.Method()) + } + + ct := headers.ContentType() + if string(ct) != "a/b" { + t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b") + } + + // Pass on any request that isn't the accepted content length + return headers.contentLength == acceptContentLength + }, + Handler: func(ctx *RequestCtx) { + if ctx.Request.Header.contentLength != acceptContentLength { + t.Errorf("all requests with content-length: other than %d, should be denied", acceptContentLength) + } + if !ctx.IsPost() { + t.Errorf("unexpected method %q. Expecting POST", ctx.Method()) + } + if string(ctx.Path()) != "/foo" { + t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo") + } + ct := ctx.Request.Header.ContentType() + if string(ct) != "a/b" { + t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b") + } + if string(ctx.PostBody()) != "12345" { + t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345") + } + ctx.WriteString("foobar") //nolint:errcheck + }, + } + + sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) { + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout") + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse) + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) > 0 { + t.Fatalf("unexpected remaining data %q", data) + } + } + + // The same server should not fail when handling the three different types of requests + // Regular requests + // Expect 100 continue accepted + // Exepect 100 continue denied + rw := &readWriter{} + for i := 0; i < 25; i++ { + + // Regular requests without Expect 100 continue header + rw.r.Reset() + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + sendRequest(rw, StatusOK, "foobar") + + // Regular Expect 100 continue reqeuests that are accepted + rw.r.Reset() + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + sendRequest(rw, StatusOK, "foobar") + + // Requests being denied + rw.r.Reset() + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456") + sendRequest(rw, StatusExpectationFailed, "") + } +} + func TestCompressHandler(t *testing.T) { t.Parallel()