From 954a0615dc98de9d5d3d843bb59cd2db5d9a8e21 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Thu, 19 Nov 2015 11:51:04 +0200 Subject: [PATCH] Fixed a typo in RequestCtx.SetConnectionClose() --- http.go | 20 ++++++++++++++++++++ server.go | 11 +++++------ server_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/http.go b/http.go index 7308aa4..c11119f 100644 --- a/http.go +++ b/http.go @@ -64,6 +64,26 @@ func (resp *Response) SetStatusCode(statusCode int) { resp.Header.SetStatusCode(statusCode) } +// ConnectionClose returns true if 'Connection: close' header is set. +func (resp *Response) ConnectionClose() bool { + return resp.Header.ConnectionClose() +} + +// SetConnectionClose sets 'Connection: close' header. +func (resp *Response) SetConnectionClose() { + resp.Header.SetConnectionClose() +} + +// ConnectionClose returns true if 'Connection: close' header is set. +func (req *Request) ConnectionClose() bool { + return req.Header.ConnectionClose() +} + +// SetConnectionClose sets 'Connection: close' header. +func (req *Request) SetConnectionClose() { + req.Header.SetConnectionClose() +} + // SetBodyStream sets response body stream and, optionally body size. // // If bodySize is >= 0, then bodySize bytes are read from bodyStream diff --git a/server.go b/server.go index 70584d8..015e73d 100644 --- a/server.go +++ b/server.go @@ -295,7 +295,7 @@ func (ctx *RequestCtx) ServeConnRequestNum() uint64 { // SetConnectionClose sets 'Connection: close' response header and closes // connection after the RequestHandler returns. func (ctx *RequestCtx) SetConnectionClose() { - ctx.Request.Header.SetConnectionClose() + ctx.Response.Header.SetConnectionClose() } // SetStatusCode sets response status code. @@ -748,6 +748,9 @@ func (s *Server) serveConn(c net.Conn) error { ctx = s.acquireCtx(c) ctx.Error(errMsg, StatusRequestTimeout) } + if s.MaxRequestsPerConn > 0 && ctx.serveConnRequestNum >= uint64(s.MaxRequestsPerConn) { + ctx.SetConnectionClose() + } if s.WriteTimeout > 0 { if err = c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil { @@ -760,11 +763,7 @@ func (s *Server) serveConn(c net.Conn) error { if err = writeResponse(ctx, bw); err != nil { break } - if s.MaxRequestsPerConn > 0 && ctx.serveConnRequestNum >= uint64(s.MaxRequestsPerConn) { - connectionClose = true - } else { - connectionClose = ctx.Response.Header.ConnectionClose() || ctx.Request.Header.ConnectionClose() - } + connectionClose = ctx.Response.Header.ConnectionClose() || ctx.Request.Header.ConnectionClose() if br == nil || connectionClose { err = bw.Flush() diff --git a/server_test.go b/server_test.go index 98e6a07..c6d7f2d 100644 --- a/server_test.go +++ b/server_test.go @@ -129,6 +129,49 @@ func TestServerTimeoutError(t *testing.T) { } } +func TestServerMaxRequestsPerConn(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) {}, + MaxRequestsPerConn: 1, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + br := bufio.NewReader(&rw.w) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when parsing response: %s", err) + } + if !resp.ConnectionClose() { + t.Fatalf("Response must have 'connection: close' header") + } + verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType)) + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") + } +} + func TestServerConnectionClose(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) {