diff --git a/header.go b/header.go index 66d2be4..dae683a 100644 --- a/header.go +++ b/header.go @@ -63,38 +63,34 @@ type RequestHeader struct { // ConnectionClose returns true if 'Connection: close' header is set. func (h *ResponseHeader) ConnectionClose() bool { - if !h.rawHeadersParsed { - if h.ContentLength() == -2 { - return true - } - if bytes.Equal(peekRawHeader(h.rawHeaders, strConnection), strClose) { - return true - } - // h.parseRawHeaders() isn't called for performance reasons. - } + h.parseRawHeaders() return h.connectionClose } // SetConnectionClose sets 'Connection: close' header. func (h *ResponseHeader) SetConnectionClose() { - h.parseRawHeaders() + // h.parseRawHeaders() isn't called for performance reasons. h.connectionClose = true } // ConnectionClose returns true if 'Connection: close' header is set. func (h *RequestHeader) ConnectionClose() bool { + // h.parseRawHeaders() isn't called for performance reasons. if !h.rawHeadersParsed { - if bytes.Equal(peekRawHeader(h.rawHeaders, strConnection), strClose) { + if h.connectionClose { + return true + } + if hasRawHeader(h.rawHeaders, strConnectionClose) { + h.connectionClose = true return true } - // h.parseRawHeaders() isn't called for performance reasons. } return h.connectionClose } // SetConnectionClose sets 'Connection: close' header. func (h *RequestHeader) SetConnectionClose() { - h.parseRawHeaders() + // h.parseRawHeaders() isn't called for performance reasons. h.connectionClose = true } @@ -104,28 +100,7 @@ func (h *RequestHeader) SetConnectionClose() { // -1 means Transfer-Encoding: chunked. // -2 means Transfer-Encoding: identity. func (h *ResponseHeader) ContentLength() int { - if !h.rawHeadersParsed { - if h.contentLength < 0 || len(h.contentLengthBytes) > 0 { - return h.contentLength - } - value := peekRawHeader(h.rawHeaders, strTransferEncoding) - if len(value) > 0 { - if bytes.Equal(value, strIdentity) { - h.connectionClose = true - h.contentLength = -2 - return h.contentLength - } - h.contentLength = -1 - return h.contentLength - } - value = peekRawHeader(h.rawHeaders, strContentLength) - if contentLength, err := parseContentLength(value); err == nil { - h.contentLength = contentLength - h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) - return contentLength - } - h.parseRawHeaders() - } + h.parseRawHeaders() return h.contentLength } @@ -156,62 +131,13 @@ func (h *ResponseHeader) SetContentLength(contentLength int) { // It may be negative: // -1 means Transfer-Encoding: chunked. func (h *RequestHeader) ContentLength() int { - if h.IsGet() || h.IsHead() { + if !h.IsPost() { return 0 } - if !h.rawHeadersParsed { - if h.contentLength < 0 || len(h.contentLengthBytes) > 0 { - return h.contentLength - } - value := peekRawHeader(h.rawHeaders, strTransferEncoding) - if len(value) > 0 { - h.contentLength = -1 - return h.contentLength - } - value = peekRawHeader(h.rawHeaders, strContentLength) - if contentLength, err := parseContentLength(value); err == nil { - h.contentLength = contentLength - h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) - return contentLength - } - h.parseRawHeaders() - } + h.parseRawHeaders() return h.contentLength } -func peekRawHeader(buf, key []byte) []byte { - n := bytes.Index(buf, key) - if n < 0 { - return nil - } - if n > 0 && buf[n-1] != '\n' { - return nil - } - buf = buf[n+len(key):] - if len(buf) == 0 { - return nil - } - if buf[0] != ':' { - return nil - } - n = 1 - for len(buf) > n && buf[n] == ' ' { - n++ - } - buf = buf[n:] - n = bytes.IndexByte(buf, '\n') - if n < 0 { - return nil - } - if n > 0 && buf[n-1] == '\r' { - n-- - } - for n > 0 && buf[n-1] == ' ' { - n-- - } - return buf[:n] -} - // SetContentLength sets Content-Length header value. // // Negative content-length sets 'Transfer-Encoding: chunked' header. @@ -1076,6 +1002,7 @@ func (h *ResponseHeader) parse(buf []byte) (int, error) { return 0, err } h.rawHeaders = rawHeaders + return m + n, nil } @@ -1089,6 +1016,7 @@ func (h *RequestHeader) parse(buf []byte) (int, error) { return 0, err } h.rawHeaders = rawHeaders + return m + n, nil } @@ -1160,6 +1088,28 @@ func (h *RequestHeader) parseFirstLine(buf []byte) (int, error) { return len(buf) - len(bNext), nil } +func hasRawHeader(buf, s []byte) bool { + n := bytes.Index(buf, s) + if n < 0 { + return false + } + if n > 0 && buf[n-1] != '\n' { + return false + } + n += len(s) + if n >= len(buf) { + return false + } + switch buf[n] { + case '\r': + return len(buf) > n+1 && buf[n+1] == '\n' + case '\n': + return true + default: + return false + } +} + func readRawHeaders(dst, buf []byte) ([]byte, int, error) { dst = dst[:0] n := bytes.IndexByte(buf, '\n') diff --git a/header_test.go b/header_test.go index 5381119..efc9de6 100644 --- a/header_test.go +++ b/header_test.go @@ -10,30 +10,30 @@ import ( "testing" ) -func TestPeekRawHeader(t *testing.T) { +func TestHasRawHeader(t *testing.T) { // empty header - testPeekRawHeader(t, "", "Foo-Bar", "") + testHasRawHeader(t, "", "Foo-Bar", false) // different case - testPeekRawHeader(t, "Content-Length: 3443\r\n", "content-length", "") + testHasRawHeader(t, "Content-Length: 3443\r\n", "content-length: 3443", false) // no trailing crlf - testPeekRawHeader(t, "Content-Length: 234", "Content-Length", "") + testHasRawHeader(t, "Content-Length: 234", "Content-Length: 234", false) // single header - testPeekRawHeader(t, "Content-Length: 12345\r\n", "Content-Length", "12345") + testHasRawHeader(t, "Content-Length: 12345\r\n", "Content-Length: 12345", true) // multiple headers - testPeekRawHeader(t, "Host: foobar\r\nContent-Length: 434\r\nFoo: bar\r\n\r\n", "Content-Length", "434") + testHasRawHeader(t, "Host: foobar\r\nContent-Length: 434\r\nFoo: bar\r\n\r\n", "Content-Length: 434", true) // lf without cr - testPeekRawHeader(t, "Foo: bar\nConnection: close\nAaa: bbb\ncc: ddd\n", "Connection", "close") + testHasRawHeader(t, "Foo: bar\nConnection: close\nAaa: bbb\ncc: ddd\n", "Connection: close", true) } -func testPeekRawHeader(t *testing.T, rawHeaders, key, expectedValue string) { - v := peekRawHeader([]byte(rawHeaders), []byte(key)) - if string(v) != expectedValue { - t.Fatalf("unexpected raw headers value %q. Expected %q. key %q, rawHeaders %q", v, expectedValue, key, rawHeaders) +func testHasRawHeader(t *testing.T, rawHeaders, s string, expectedValue bool) { + v := hasRawHeader([]byte(rawHeaders), []byte(s)) + if v != expectedValue { + t.Fatalf("unexpected raw headers value %v. Expected %v. s %q, rawHeaders %q", v, expectedValue, s, rawHeaders) } } @@ -811,7 +811,7 @@ func TestResponseHeaderReadSuccess(t *testing.T) { // duplicate content-length testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 456\r\nContent-Type: foo/bar\r\nContent-Length: 321\r\n\r\n", - 200, 456, "foo/bar", "") + 200, 321, "foo/bar", "") // duplicate content-type testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 234\r\nContent-Type: foo/bar\r\nContent-Type: baz/bar\r\n\r\n", @@ -936,7 +936,7 @@ func TestRequestHeaderReadSuccess(t *testing.T) { // post with duplicate content-length testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n", - 13, "/xx", "aa", "", "s", "") + 1, "/xx", "aa", "", "s", "") // non-post with content-type testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n", diff --git a/strings.go b/strings.go index bf8f10e..70bd9d9 100644 --- a/strings.go +++ b/strings.go @@ -23,6 +23,7 @@ var ( strPost = []byte("POST") strConnection = []byte("Connection") + strConnectionClose = []byte("Connection: close") strContentLength = []byte("Content-Length") strContentType = []byte("Content-Type") strDate = []byte("Date")