mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
allow the expect 100 continue workflow to deny requests (#787)
* allow the expect 100 continue workflow to deny requests * suggested changes * update booleans to reflect handler name change
This commit is contained in:
@@ -172,6 +172,17 @@ type Server struct {
|
||||
// non zero RequestConfig field values will overwrite the default configs
|
||||
HeaderReceived func(header *RequestHeader) RequestConfig
|
||||
|
||||
// ContinueHandler is called after receiving the Expect 100 Continue Header
|
||||
//
|
||||
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
|
||||
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
|
||||
// Using ContinueHandler a server can make decisioning on whether or not
|
||||
// to read a potentially large request body based on the headers
|
||||
//
|
||||
// The default is to automatically read request bodies of Expect 100 Continue requests
|
||||
// like they are normal requests
|
||||
ContinueHandler func(header *RequestHeader) bool
|
||||
|
||||
// Server name for sending in response headers.
|
||||
//
|
||||
// Default server name is used if left blank.
|
||||
@@ -1935,7 +1946,8 @@ func (s *Server) serveConn(c net.Conn) (err error) {
|
||||
connectionClose bool
|
||||
isHTTP11 bool
|
||||
|
||||
reqReset bool
|
||||
reqReset bool
|
||||
continueReadingRequest bool = true
|
||||
)
|
||||
for {
|
||||
connRequestNum++
|
||||
@@ -2007,6 +2019,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
|
||||
//read body
|
||||
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// If we read any bytes off the wire, we're active.
|
||||
s.setState(c, StateActive)
|
||||
@@ -2041,37 +2054,53 @@ func (s *Server) serveConn(c net.Conn) (err error) {
|
||||
}
|
||||
|
||||
// 'Expect: 100-continue' request handling.
|
||||
// See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details.
|
||||
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
|
||||
if ctx.Request.MayContinue() {
|
||||
// Send 'HTTP/1.1 100 Continue' response.
|
||||
if bw == nil {
|
||||
bw = acquireWriter(ctx)
|
||||
}
|
||||
_, err = bw.Write(strResponseContinue)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = bw.Flush()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if s.ReduceMemoryUsage {
|
||||
releaseWriter(s, bw)
|
||||
bw = nil
|
||||
|
||||
// Allow the ability to deny reading the incoming request body
|
||||
if s.ContinueHandler != nil {
|
||||
if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
|
||||
if br != nil {
|
||||
br.Reset(ctx.c)
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(StatusExpectationFailed)
|
||||
}
|
||||
}
|
||||
|
||||
// Read request body.
|
||||
if br == nil {
|
||||
br = acquireReader(ctx)
|
||||
}
|
||||
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
|
||||
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
|
||||
releaseReader(s, br)
|
||||
br = nil
|
||||
}
|
||||
if err != nil {
|
||||
bw = s.writeErrorResponse(bw, ctx, serverName, err)
|
||||
break
|
||||
if continueReadingRequest {
|
||||
if bw == nil {
|
||||
bw = acquireWriter(ctx)
|
||||
}
|
||||
|
||||
// Send 'HTTP/1.1 100 Continue' response.
|
||||
_, err = bw.Write(strResponseContinue)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = bw.Flush()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if s.ReduceMemoryUsage {
|
||||
releaseWriter(s, bw)
|
||||
bw = nil
|
||||
}
|
||||
|
||||
// Read request body.
|
||||
if br == nil {
|
||||
br = acquireReader(ctx)
|
||||
}
|
||||
|
||||
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
|
||||
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
|
||||
releaseReader(s, br)
|
||||
br = nil
|
||||
}
|
||||
if err != nil {
|
||||
bw = s.writeErrorResponse(bw, ctx, serverName, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2084,7 +2113,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
|
||||
ctx.connID = connID
|
||||
ctx.connRequestNum = connRequestNum
|
||||
ctx.time = time.Now()
|
||||
s.Handler(ctx)
|
||||
|
||||
// If a client denies a request the handler should not be called
|
||||
if continueReadingRequest {
|
||||
s.Handler(ctx)
|
||||
}
|
||||
|
||||
timeoutResponse = ctx.timeoutResponse
|
||||
if timeoutResponse != nil {
|
||||
|
||||
@@ -1654,6 +1654,96 @@ func TestServerExpect100Continue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerContinueHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
acceptContentLength := 5
|
||||
s := &Server{
|
||||
ContinueHandler: func(headers *RequestHeader) bool {
|
||||
if !headers.IsPost() {
|
||||
t.Errorf("unexpected method %q. Expecting POST", headers.Method())
|
||||
}
|
||||
|
||||
ct := headers.ContentType()
|
||||
if string(ct) != "a/b" {
|
||||
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
|
||||
}
|
||||
|
||||
// Pass on any request that isn't the accepted content length
|
||||
return headers.contentLength == acceptContentLength
|
||||
},
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
if ctx.Request.Header.contentLength != acceptContentLength {
|
||||
t.Errorf("all requests with content-length: other than %d, should be denied", acceptContentLength)
|
||||
}
|
||||
if !ctx.IsPost() {
|
||||
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
|
||||
}
|
||||
if string(ctx.Path()) != "/foo" {
|
||||
t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
|
||||
}
|
||||
ct := ctx.Request.Header.ContentType()
|
||||
if string(ct) != "a/b" {
|
||||
t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b")
|
||||
}
|
||||
if string(ctx.PostBody()) != "12345" {
|
||||
t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
|
||||
}
|
||||
ctx.WriteString("foobar") //nolint:errcheck
|
||||
},
|
||||
}
|
||||
|
||||
sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) {
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(rw)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
br := bufio.NewReader(&rw.w)
|
||||
verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse)
|
||||
|
||||
data, err := ioutil.ReadAll(br)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error when reading remaining data: %s", err)
|
||||
}
|
||||
if len(data) > 0 {
|
||||
t.Fatalf("unexpected remaining data %q", data)
|
||||
}
|
||||
}
|
||||
|
||||
// The same server should not fail when handling the three different types of requests
|
||||
// Regular requests
|
||||
// Expect 100 continue accepted
|
||||
// Exepect 100 continue denied
|
||||
rw := &readWriter{}
|
||||
for i := 0; i < 25; i++ {
|
||||
|
||||
// Regular requests without Expect 100 continue header
|
||||
rw.r.Reset()
|
||||
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
|
||||
sendRequest(rw, StatusOK, "foobar")
|
||||
|
||||
// Regular Expect 100 continue reqeuests that are accepted
|
||||
rw.r.Reset()
|
||||
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
|
||||
sendRequest(rw, StatusOK, "foobar")
|
||||
|
||||
// Requests being denied
|
||||
rw.r.Reset()
|
||||
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456")
|
||||
sendRequest(rw, StatusExpectationFailed, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user