diff --git a/client.go b/client.go index 676a30a..cce9db2 100644 --- a/client.go +++ b/client.go @@ -545,7 +545,7 @@ func (c *HostClient) do(req *Request, resp *Response, newConn bool) (bool, error } c.releaseReader(br) - if req.Header.ConnectionClose || resp.Header.ConnectionClose { + if req.Header.ConnectionClose() || resp.Header.ConnectionClose() { c.closeConn(cc) } else { c.releaseConn(cc) diff --git a/header.go b/header.go index e93bb67..c891cef 100644 --- a/header.go +++ b/header.go @@ -18,8 +18,7 @@ type ResponseHeader struct { // Response status code. StatusCode int - // Set to true if response contains 'Connection: close' header. - ConnectionClose bool + connectionClose bool contentLength int contentLengthBytes []byte @@ -38,8 +37,7 @@ type ResponseHeader struct { // It is forbidden copying RequestHeader instances. // Create new instances instead and use CopyTo. type RequestHeader struct { - // Set to true if request contains 'Connection: close' header. - ConnectionClose bool + connectionClose bool contentLength int contentLengthBytes []byte @@ -57,6 +55,26 @@ type RequestHeader struct { cookiesCollected bool } +// ConnectionClose returns true if 'Connection: close' header is set. +func (h *ResponseHeader) ConnectionClose() bool { + return h.connectionClose +} + +// SetConnectionClose sets 'Connection: close' header. +func (h *ResponseHeader) SetConnectionClose() { + h.connectionClose = true +} + +// ConnectionClose returns true if 'Connection: close' header is set. +func (h *RequestHeader) ConnectionClose() bool { + return h.connectionClose +} + +// SetConnectionClose sets 'Connection: close' header. +func (h *RequestHeader) SetConnectionClose() { + h.connectionClose = true +} + // ContentLength returns Content-Length header value. // // It may be negative: @@ -80,6 +98,7 @@ func (h *ResponseHeader) SetContentLength(contentLength int) { h.contentLengthBytes = h.contentLengthBytes[:0] value := strChunked if contentLength == -2 { + h.SetConnectionClose() value = strIdentity } h.h = setArg(h.h, strTransferEncoding, value) @@ -272,7 +291,7 @@ func (h *RequestHeader) Len() int { // Clear clears response header. func (h *ResponseHeader) Clear() { h.StatusCode = 0 - h.ConnectionClose = false + h.connectionClose = false h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] @@ -286,7 +305,7 @@ func (h *ResponseHeader) Clear() { // Clear clears request header. func (h *RequestHeader) Clear() { - h.ConnectionClose = false + h.connectionClose = false h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] @@ -306,7 +325,7 @@ func (h *RequestHeader) Clear() { func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { dst.Clear() dst.StatusCode = h.StatusCode - dst.ConnectionClose = h.ConnectionClose + dst.connectionClose = h.connectionClose dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.contentType = append(dst.contentType[:0], h.contentType...) @@ -318,7 +337,7 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { // CopyTo copies all the headers to dst. func (h *RequestHeader) CopyTo(dst *RequestHeader) { dst.Clear() - dst.ConnectionClose = h.ConnectionClose + dst.connectionClose = h.connectionClose dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.method = append(dst.method[:0], h.method...) @@ -353,7 +372,7 @@ func (h *ResponseHeader) VisitAll(f func(key, value []byte)) { }) } visitArgs(h.h, f) - if h.ConnectionClose { + if h.ConnectionClose() { f(strConnection, strClose) } } @@ -404,7 +423,7 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) { f(strCookie, h.bufKV.value) } visitArgs(h.h, f) - if h.ConnectionClose { + if h.ConnectionClose() { f(strConnection, strClose) } } @@ -488,7 +507,7 @@ func (h *ResponseHeader) SetCanonical(key, value []byte) { } case bytes.Equal(strConnection, key): if bytes.Equal(strClose, value) { - h.ConnectionClose = true + h.SetConnectionClose() } // skip other 'Connection' shit :) case bytes.Equal(strTransferEncoding, key): @@ -583,7 +602,7 @@ func (h *RequestHeader) SetCanonical(key, value []byte) { } case bytes.Equal(strConnection, key): if bytes.Equal(strClose, value) { - h.ConnectionClose = true + h.SetConnectionClose() } // skip other 'Connection' shit :) case bytes.Equal(strTransferEncoding, key): @@ -640,7 +659,7 @@ func (h *ResponseHeader) peek(key []byte) []byte { case bytes.Equal(strServer, key): return h.Server() case bytes.Equal(strConnection, key): - if h.ConnectionClose { + if h.ConnectionClose() { return strClose } return nil @@ -660,7 +679,7 @@ func (h *RequestHeader) peek(key []byte) []byte { case bytes.Equal(strUserAgent, key): return h.UserAgent() case bytes.Equal(strConnection, key): - if h.ConnectionClose { + if h.ConnectionClose() { return strClose } return nil @@ -833,7 +852,7 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error { } } - if h.ConnectionClose { + if h.ConnectionClose() { writeHeaderLine(w, strConnection, strClose) } @@ -889,7 +908,7 @@ func (h *RequestHeader) Write(w *bufio.Writer) error { writeHeaderLine(w, strCookie, h.bufKV.value) } - if h.ConnectionClose { + if h.ConnectionClose() { writeHeaderLine(w, strConnection, strClose) } @@ -935,7 +954,7 @@ func (h *ResponseHeader) parseFirstLine(buf []byte) (b []byte, err error) { } if !bytes.Equal(b[:n], strHTTP11) { // Non-http/1.1 response. Close connection after it. - h.ConnectionClose = true + h.SetConnectionClose() } b = b[n+1:] @@ -971,13 +990,13 @@ func (h *RequestHeader) parseFirstLine(buf []byte) (b []byte, err error) { n = bytes.LastIndexByte(b, ' ') if n < 0 { // no http protocol found. Close connection after the request. - h.ConnectionClose = true + h.SetConnectionClose() n = len(b) } else if n == 0 { return nil, fmt.Errorf("RequestURI cannot be empty in %q", buf) } else if !bytes.Equal(b[n+1:], strHTTP11) { // non-http/1.1 protocol. Close connection after the request. - h.ConnectionClose = true + h.SetConnectionClose() } h.SetRequestURIBytes(b[:n]) @@ -1017,7 +1036,7 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) { } case bytes.Equal(s.key, strConnection): if bytes.Equal(s.value, strClose) { - h.ConnectionClose = true + h.SetConnectionClose() } case bytes.Equal(s.key, strSetCookie): h.cookies, kv = allocArg(h.cookies) @@ -1037,8 +1056,6 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) { return nil, fmt.Errorf("missing required Content-Type header in %q", buf) } if contentLength == -2 { - // Close connection after 'identity' response. - h.ConnectionClose = true h.SetContentLength(contentLength) } return s.b, nil @@ -1078,7 +1095,7 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) { } case bytes.Equal(s.key, strConnection): if bytes.Equal(s.value, strClose) { - h.ConnectionClose = true + h.SetConnectionClose() } default: h.h, kv = allocArg(h.h) diff --git a/header_test.go b/header_test.go index 4045511..2d81ee8 100644 --- a/header_test.go +++ b/header_test.go @@ -92,8 +92,8 @@ func testResponseHeaderHTTPVer(t *testing.T, s string, connectionClose bool) { if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. response=%q", err, s) } - if h.ConnectionClose != connectionClose { - t.Fatalf("unexpected connectionClose %v. Expecting %v. response=%q", h.ConnectionClose, connectionClose, s) + if h.ConnectionClose() != connectionClose { + t.Fatalf("unexpected connectionClose %v. Expecting %v. response=%q", h.ConnectionClose(), connectionClose, s) } } @@ -105,8 +105,8 @@ func testRequestHeaderHTTPVer(t *testing.T, s string, connectionClose bool) { if err := h.Read(br); err != nil { t.Fatalf("unexpected error: %s. request=%q", err, s) } - if h.ConnectionClose != connectionClose { - t.Fatalf("unexpected connectionClose %v. Expecting %v. request=%q", h.ConnectionClose, connectionClose, s) + if h.ConnectionClose() != connectionClose { + t.Fatalf("unexpected connectionClose %v. Expecting %v. request=%q", h.ConnectionClose(), connectionClose, s) } } @@ -159,7 +159,7 @@ func TestRequestHeaderConnectionClose(t *testing.T) { h.Set("Connection", "close") h.Set("Host", "foobar") - if !h.ConnectionClose { + if !h.ConnectionClose() { t.Fatalf("connection: close not set") } @@ -178,8 +178,8 @@ func TestRequestHeaderConnectionClose(t *testing.T) { t.Fatalf("error when reading request header: %s", err) } - if !h1.ConnectionClose { - t.Fatalf("unexpected connection: close value: %v", h1.ConnectionClose) + if !h1.ConnectionClose() { + t.Fatalf("unexpected connection: close value: %v", h1.ConnectionClose()) } if string(h1.Peek("Connection")) != "close" { t.Fatalf("unexpected connection value: %q. Expecting %q", h.Peek("Connection"), "close") @@ -507,7 +507,7 @@ func TestRequestHeaderSetGet(t *testing.T) { expectRequestHeaderGet(t, h, "baz", "xxxxx") expectRequestHeaderGet(t, h, "Transfer-Encoding", "") expectRequestHeaderGet(t, h, "connecTION", "close") - if !h.ConnectionClose { + if !h.ConnectionClose() { t.Fatalf("unset connection: close") } @@ -544,7 +544,7 @@ func TestRequestHeaderSetGet(t *testing.T) { expectRequestHeaderGet(t, &h1, "baz", "xxxxx") expectRequestHeaderGet(t, &h1, "Transfer-Encoding", "") expectRequestHeaderGet(t, &h1, "Connection", "close") - if !h1.ConnectionClose { + if !h1.ConnectionClose() { t.Fatalf("unset connection: close") } } @@ -570,8 +570,8 @@ func TestResponseHeaderSetGet(t *testing.T) { if h.ContentLength() != 1234 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234) } - if !h.ConnectionClose { - t.Fatalf("Unexpected Connection: close value %v. Expected %v", h.ConnectionClose, true) + if !h.ConnectionClose() { + t.Fatalf("Unexpected Connection: close value %v. Expected %v", h.ConnectionClose(), true) } w := &bytes.Buffer{} @@ -593,8 +593,8 @@ func TestResponseHeaderSetGet(t *testing.T) { if h1.ContentLength() != h.ContentLength() { t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength(), h.ContentLength()) } - if h1.ConnectionClose != h.ConnectionClose { - t.Fatalf("unexpected connection: close %v. Expected %v", h1.ConnectionClose, h.ConnectionClose) + if h1.ConnectionClose() != h.ConnectionClose() { + t.Fatalf("unexpected connection: close %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose()) } expectResponseHeaderGet(t, &h1, "Foo", "bar") @@ -622,8 +622,9 @@ func TestResponseHeaderConnectionClose(t *testing.T) { } func testResponseHeaderConnectionClose(t *testing.T, connectionClose bool) { - h := &ResponseHeader{ - ConnectionClose: connectionClose, + h := &ResponseHeader{} + if connectionClose { + h.SetConnectionClose() } h.SetContentLength(123) @@ -643,8 +644,8 @@ func testResponseHeaderConnectionClose(t *testing.T, connectionClose bool) { if err != nil { t.Fatalf("Unexpected error when reading response header: %s", err) } - if h1.ConnectionClose != h.ConnectionClose { - t.Fatalf("Unexpected value for ConnectionClose: %v. Expected %v", h1.ConnectionClose, h.ConnectionClose) + if h1.ConnectionClose() != h.ConnectionClose() { + t.Fatalf("Unexpected value for ConnectionClose: %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose()) } } @@ -731,21 +732,21 @@ func TestResponseHeaderReadSuccess(t *testing.T) { // straight order of content-length and content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n", 200, 123, "text/html", "") - if h.ConnectionClose { + if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } // reverse order of content-length and content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 202 OK\r\nContent-Type: text/plain; encoding=utf-8\r\nContent-Length: 543\r\nConnection: close\r\n\r\n", 202, 543, "text/plain; encoding=utf-8", "") - if !h.ConnectionClose { + if !h.ConnectionClose() { t.Fatalf("expecting connection: close") } // tranfer-encoding: chunked testResponseHeaderReadSuccess(t, h, "HTTP/1.1 505 Internal error\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n", 505, -1, "text/html", "") - if h.ConnectionClose { + if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } @@ -807,14 +808,14 @@ func TestResponseHeaderReadSuccess(t *testing.T) { // blank lines before the first line testResponseHeaderReadSuccess(t, h, "\r\nHTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 0\r\n\r\nsss", 200, 0, "aa", "sss") - if h.ConnectionClose { + if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } // no content-length (identity transfer-encoding) testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\n\r\nabcdefg", 200, -2, "foo/bar", "abcdefg") - if !h.ConnectionClose { + if !h.ConnectionClose() { t.Fatalf("expecting connection: close for identity response") } } @@ -825,28 +826,28 @@ func TestRequestHeaderReadSuccess(t *testing.T) { // simple headers testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\n", 0, "/foo/bar", "google.com", "", "", "") - if h.ConnectionClose { + if h.ConnectionClose() { t.Fatalf("unexpected connection: close header") } // simple headers with body testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\nconneCTION: close\r\n\r\nfoobar", 0, "/a/bar", "gole.com", "", "", "foobar") - if !h.ConnectionClose { + if !h.ConnectionClose() { t.Fatalf("connection: close unset") } // ancient http protocol testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\nHost: gole\r\n\r\npppp", 0, "/bar", "gole", "", "", "pppp") - if !h.ConnectionClose { + if !h.ConnectionClose() { t.Fatalf("expecting connectionClose for ancient http protocol") } // complex headers with body testRequestHeaderReadSuccess(t, h, "GET /aabar HTTP/1.1\r\nAAA: bbb\r\nHost: ole.com\r\nAA: bb\r\n\r\nzzz", 0, "/aabar", "ole.com", "", "", "zzz") - if h.ConnectionClose { + if h.ConnectionClose() { t.Fatalf("unexpected connection: close") } diff --git a/server.go b/server.go index 26f7392..bed6cc4 100644 --- a/server.go +++ b/server.go @@ -261,6 +261,12 @@ var zeroTCPAddr = &net.TCPAddr{ IP: net.IPv4zero, } +// SetConnectionClose sets 'Connection: close' response header and closes +// connection after the RequestHandler returns. +func (ctx *RequestCtx) SetConnectionClose() { + ctx.Request.Header.SetConnectionClose() +} + // RequestURI returns RequestURI. // // This uri is valid until returning from RequestHandler. @@ -673,7 +679,7 @@ func (s *Server) serveConn(c net.Conn) error { if err = writeResponse(ctx, bw); err != nil { break } - connectionClose = ctx.Request.Header.ConnectionClose || ctx.Response.Header.ConnectionClose + connectionClose = ctx.Request.Header.ConnectionClose() || ctx.Response.Header.ConnectionClose() if br == nil || connectionClose { err = bw.Flush() diff --git a/server_test.go b/server_test.go index ec5a0d8..c4c4f88 100644 --- a/server_test.go +++ b/server_test.go @@ -132,7 +132,7 @@ func TestServerTimeoutError(t *testing.T) { func TestServerConnectionClose(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.Response.Header.ConnectionClose = true + ctx.SetConnectionClose() }, } diff --git a/server_timing_test.go b/server_timing_test.go index a3cfc07..98142fe 100644 --- a/server_timing_test.go +++ b/server_timing_test.go @@ -283,7 +283,7 @@ func benchmarkServerGet(b *testing.B, clientsCount, requestsPerConn int) { } ctx.Success("text/plain", fakeResponse) if requestsPerConn == 1 { - ctx.Response.Header.ConnectionClose = true + ctx.SetConnectionClose() } registerServedRequest(b, ch) }, @@ -326,7 +326,7 @@ func benchmarkServerPost(b *testing.B, clientsCount, requestsPerConn int) { } ctx.Success("text/plain", body) if requestsPerConn == 1 { - ctx.Response.Header.ConnectionClose = true + ctx.SetConnectionClose() } registerServedRequest(b, ch) },