diff --git a/header.go b/header.go index 70b910b..3013541 100644 --- a/header.go +++ b/header.go @@ -51,6 +51,9 @@ type RequestHeader struct { // It may be negative on chunked request. ContentLength int + // Set to true if request contains 'Connection: close' header. + ConnectionClose bool + host []byte contentType []byte userAgent []byte @@ -110,6 +113,7 @@ func (h *RequestHeader) Clear() { h.Method = h.Method[:0] h.RequestURI = h.RequestURI[:0] h.ContentLength = 0 + h.ConnectionClose = false h.host = h.host[:0] h.contentType = h.contentType[:0] @@ -137,6 +141,7 @@ func (h *RequestHeader) CopyTo(dst *RequestHeader) { dst.Method = append(dst.Method[:0], h.Method...) dst.RequestURI = append(dst.RequestURI[:0], h.RequestURI...) dst.ContentLength = h.ContentLength + dst.ConnectionClose = h.ConnectionClose dst.host = append(dst.host[:0], h.host...) dst.contentType = append(dst.contentType[:0], h.contentType...) dst.userAgent = append(dst.userAgent[:0], h.userAgent...) @@ -160,10 +165,10 @@ func (h *ResponseHeader) VisitAll(f func(key, value []byte)) { f(strSetCookie, v) }) } + visitArgs(h.h, f) if h.ConnectionClose { f(strConnection, strClose) } - visitArgs(h.h, f) } // VisitAllCookie calls f for each response cookie. @@ -203,6 +208,9 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) { f(strCookie, h.bufKV.value) } visitArgs(h.h, f) + if h.ConnectionClose { + f(strConnection, strClose) + } } // Del deletes header with the given key. @@ -367,6 +375,11 @@ func (h *RequestHeader) SetCanonical(key, value []byte) { h.userAgent = append(h.userAgent[:0], value...) case bytes.Equal(strContentLength, key): // Content-Length is managed automatically. + case bytes.Equal(strConnection, key): + if bytes.Equal(strClose, value) { + h.ConnectionClose = true + } + // skip other 'Connection' shit :) case bytes.Equal(strTransferEncoding, key): // Transfer-Encoding is managed automatically. case bytes.Equal(strConnection, key): @@ -440,6 +453,11 @@ func (h *RequestHeader) peek(key []byte) []byte { return h.contentType case bytes.Equal(strUserAgent, key): return h.userAgent + case bytes.Equal(strConnection, key): + if h.ConnectionClose { + return strClose + } + return nil default: return peekArg(h.h, key) } @@ -642,10 +660,6 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error { writeContentLength(w, h.ContentLength) } - if h.ConnectionClose { - writeHeaderLine(w, strConnection, strClose) - } - for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] writeHeaderLine(w, kv.key, kv.value) @@ -659,6 +673,10 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error { } } + if h.ConnectionClose { + writeHeaderLine(w, strConnection, strClose) + } + _, err := w.Write(strCRLF) return err } @@ -717,6 +735,10 @@ func (h *RequestHeader) Write(w *bufio.Writer) error { writeHeaderLine(w, strCookie, h.bufKV.value) } + if h.ConnectionClose { + writeHeaderLine(w, strConnection, strClose) + } + _, err := w.Write(strCRLF) return err } @@ -889,6 +911,10 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) { if bytes.Equal(s.value, strChunked) { h.ContentLength = -1 } + case bytes.Equal(s.key, strConnection): + if bytes.Equal(s.value, strClose) { + h.ConnectionClose = true + } case bytes.Equal(s.key, strCookie): h.cookies = parseRequestCookies(h.cookies, s.value) default: diff --git a/header_test.go b/header_test.go index 5e70c74..f8ed6a6 100644 --- a/header_test.go +++ b/header_test.go @@ -54,6 +54,38 @@ func TestRequestHeaderCopyTo(t *testing.T) { } } +func TestRequestHeaderConnectionClose(t *testing.T) { + var h RequestHeader + + h.Set("Connection", "close") + h.Set("Host", "foobar") + if !h.ConnectionClose { + t.Fatalf("connection: close not set") + } + + var w bytes.Buffer + bw := bufio.NewWriter(&w) + if err := h.Write(bw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var h1 RequestHeader + br := bufio.NewReader(&w) + if err := h1.Read(br); err != nil { + t.Fatalf("error when reading request header: %s", err) + } + + if !h1.ConnectionClose { + t.Fatalf("unexpected connection: close value: %v", h1.ConnectionClose) + } + if h1.Get("Connection") != "close" { + t.Fatalf("unexpected connection value: %q. Expecting %q", h.Get("Connection"), "close") + } +} + func TestRequestHeaderSetCookie(t *testing.T) { var h RequestHeader @@ -356,6 +388,7 @@ func TestRequestHeaderSetGet(t *testing.T) { h.Set("referer", "axcv") h.Set("baz", "xxxxx") h.Set("transfer-encoding", "chunked") + h.Set("connection", "close") expectRequestHeaderGet(t, h, "Foo", "bar") expectRequestHeaderGet(t, h, "Host", "12345") @@ -365,6 +398,10 @@ func TestRequestHeaderSetGet(t *testing.T) { expectRequestHeaderGet(t, h, "Referer", "axcv") expectRequestHeaderGet(t, h, "baz", "xxxxx") expectRequestHeaderGet(t, h, "Transfer-Encoding", "") + expectRequestHeaderGet(t, h, "connecTION", "close") + if !h.ConnectionClose { + t.Fatalf("unset connection: close") + } if h.ContentLength != 0 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0) @@ -398,6 +435,10 @@ func TestRequestHeaderSetGet(t *testing.T) { expectRequestHeaderGet(t, &h1, "Referer", "axcv") expectRequestHeaderGet(t, &h1, "baz", "xxxxx") expectRequestHeaderGet(t, &h1, "Transfer-Encoding", "") + expectRequestHeaderGet(t, &h1, "Connection", "close") + if !h1.ConnectionClose { + t.Fatalf("unset connection: close") + } } func TestResponseHeaderSetGet(t *testing.T) { @@ -580,14 +621,23 @@ func TestResponseHeaderReadSuccess(t *testing.T) { // straight order of content-length and content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n", 200, 123, "text/html", "") + if h.ConnectionClose { + t.Fatalf("unexpected connection: close") + } // reverse order of content-length and content-type - testResponseHeaderReadSuccess(t, h, "HTTP/1.1 202 OK\r\nContent-Type: text/plain; encoding=utf-8\r\nContent-Length: 543\r\n\r\n", + testResponseHeaderReadSuccess(t, h, "HTTP/1.1 202 OK\r\nContent-Type: text/plain; encoding=utf-8\r\nContent-Length: 543\r\nConnection: close\r\n\r\n", 202, 543, "text/plain; encoding=utf-8", "") + if !h.ConnectionClose { + t.Fatalf("expecting connection: close") + } // tranfer-encoding: chunked testResponseHeaderReadSuccess(t, h, "HTTP/1.1 505 Internal error\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n", 505, -1, "text/html", "") + if h.ConnectionClose { + t.Fatalf("unexpected connection: close") + } // reverse order of content-type and tranfer-encoding testResponseHeaderReadSuccess(t, h, "HTTP/1.1 343 foobar\r\nTransfer-Encoding: chunked\r\nContent-Type: text/json\r\n\r\n", @@ -655,14 +705,23 @@ func TestRequestHeaderReadSuccess(t *testing.T) { // simple headers testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\n", 0, "/foo/bar", "google.com", "", "", "") + if h.ConnectionClose { + t.Fatalf("unexpected connection: close header") + } // simple headers with body - testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\n\r\nfoobar", + testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\nconneCTION: close\r\n\r\nfoobar", 0, "/a/bar", "gole.com", "", "", "foobar") + if !h.ConnectionClose { + t.Fatalf("connection: close unset") + } // ancient http protocol testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\nHost: gole\r\n\r\npppp", 0, "/bar", "gole", "", "", "pppp") + if h.ConnectionClose { + t.Fatalf("unexpected connection: close header") + } // complex headers with body testRequestHeaderReadSuccess(t, h, "GET /aabar HTTP/1.1\r\nAAA: bbb\r\nHost: ole.com\r\nAA: bb\r\n\r\nzzz", diff --git a/server.go b/server.go index c6ffc57..efcedd4 100644 --- a/server.go +++ b/server.go @@ -468,7 +468,7 @@ func (s *Server) serveConn(c net.Conn) error { if err = writeResponse(ctx, bw); err != nil { break } - connectionClose = ctx.Response.Header.ConnectionClose + connectionClose = ctx.Request.Header.ConnectionClose || ctx.Response.Header.ConnectionClose trimBigBuffers(ctx)