allow the expect 100 continue workflow to deny requests (#787)

* allow the expect 100 continue workflow to deny requests

* suggested changes

* update booleans to reflect handler name change
This commit is contained in:
Mike MacDermaid
2020-04-27 13:29:02 -05:00
committed by GitHub
parent 446e1a638d
commit 32940977fb
2 changed files with 153 additions and 30 deletions
+63 -30
View File
@@ -172,6 +172,17 @@ type Server struct {
// non zero RequestConfig field values will overwrite the default configs
HeaderReceived func(header *RequestHeader) RequestConfig
// ContinueHandler is called after receiving the Expect 100 Continue Header
//
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
// Using ContinueHandler a server can make decisioning on whether or not
// to read a potentially large request body based on the headers
//
// The default is to automatically read request bodies of Expect 100 Continue requests
// like they are normal requests
ContinueHandler func(header *RequestHeader) bool
// Server name for sending in response headers.
//
// Default server name is used if left blank.
@@ -1935,7 +1946,8 @@ func (s *Server) serveConn(c net.Conn) (err error) {
connectionClose bool
isHTTP11 bool
reqReset bool
reqReset bool
continueReadingRequest bool = true
)
for {
connRequestNum++
@@ -2007,6 +2019,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
//read body
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
}
if err == nil {
// If we read any bytes off the wire, we're active.
s.setState(c, StateActive)
@@ -2041,37 +2054,53 @@ func (s *Server) serveConn(c net.Conn) (err error) {
}
// 'Expect: 100-continue' request handling.
// See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
if ctx.Request.MayContinue() {
// Send 'HTTP/1.1 100 Continue' response.
if bw == nil {
bw = acquireWriter(ctx)
}
_, err = bw.Write(strResponseContinue)
if err != nil {
break
}
err = bw.Flush()
if err != nil {
break
}
if s.ReduceMemoryUsage {
releaseWriter(s, bw)
bw = nil
// Allow the ability to deny reading the incoming request body
if s.ContinueHandler != nil {
if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
if br != nil {
br.Reset(ctx.c)
}
ctx.SetStatusCode(StatusExpectationFailed)
}
}
// Read request body.
if br == nil {
br = acquireReader(ctx)
}
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
}
if err != nil {
bw = s.writeErrorResponse(bw, ctx, serverName, err)
break
if continueReadingRequest {
if bw == nil {
bw = acquireWriter(ctx)
}
// Send 'HTTP/1.1 100 Continue' response.
_, err = bw.Write(strResponseContinue)
if err != nil {
break
}
err = bw.Flush()
if err != nil {
break
}
if s.ReduceMemoryUsage {
releaseWriter(s, bw)
bw = nil
}
// Read request body.
if br == nil {
br = acquireReader(ctx)
}
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
}
if err != nil {
bw = s.writeErrorResponse(bw, ctx, serverName, err)
break
}
}
}
@@ -2084,7 +2113,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
ctx.connID = connID
ctx.connRequestNum = connRequestNum
ctx.time = time.Now()
s.Handler(ctx)
// If a client denies a request the handler should not be called
if continueReadingRequest {
s.Handler(ctx)
}
timeoutResponse = ctx.timeoutResponse
if timeoutResponse != nil {
+90
View File
@@ -1654,6 +1654,96 @@ func TestServerExpect100Continue(t *testing.T) {
}
}
func TestServerContinueHandler(t *testing.T) {
t.Parallel()
acceptContentLength := 5
s := &Server{
ContinueHandler: func(headers *RequestHeader) bool {
if !headers.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", headers.Method())
}
ct := headers.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
}
// Pass on any request that isn't the accepted content length
return headers.contentLength == acceptContentLength
},
Handler: func(ctx *RequestCtx) {
if ctx.Request.Header.contentLength != acceptContentLength {
t.Errorf("all requests with content-length: other than %d, should be denied", acceptContentLength)
}
if !ctx.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
}
if string(ctx.Path()) != "/foo" {
t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
}
ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
}
if string(ctx.PostBody()) != "12345" {
t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
}
ctx.WriteString("foobar") //nolint:errcheck
},
}
sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) {
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
select {
case err := <-ch:
if err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse)
data, err := ioutil.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %s", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
}
}
// The same server should not fail when handling the three different types of requests
// Regular requests
// Expect 100 continue accepted
// Exepect 100 continue denied
rw := &readWriter{}
for i := 0; i < 25; i++ {
// Regular requests without Expect 100 continue header
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
sendRequest(rw, StatusOK, "foobar")
// Regular Expect 100 continue reqeuests that are accepted
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
sendRequest(rw, StatusOK, "foobar")
// Requests being denied
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456")
sendRequest(rw, StatusExpectationFailed, "")
}
}
func TestCompressHandler(t *testing.T) {
t.Parallel()